不同的公司或者不同的公司项目,使用的数据库很可能也不同,这就需要增加对不同数据库的支持。一个测试项目的测试用例一般都非常多,每个测试用例也几乎都需要操作数据库,如果每个测试用例都需要创建数据库连接然后断开,就太消耗系统性能了。我们可以使用连接池进行数据库连接管理。

我目前需要接触的数据库是 mysql 和 redis,所以本篇以 mysql 和 redis 为内容,介绍我是怎么增加对数据库的支持的。

确定配置文件和设计模式


数据库连接信息肯定不能写在代码里,需要写在配置文件中,配置文件类我推荐 toml 格式

[mysql]
host = "192.168.9.118"
user = "root"
port = 3306
password = "Passw0rd@ylzl"
database = "xgate"
charset = "utf8mb4"
mincached = 1
maxcached = 4
maxconnections = 10
blocking = true

[redis]
host = "192.168.9.118"
port = 6379
password = "Passw0rd@ylzl"
db = "0"

max_connections = 4

实现数据获取函数

python 中自带处理 toml 文件的 tomllib 库,因为我现在只使用 toml 文件,所以只写了一个 load_toml 函数来处理。如果读者决定支持其他格式文件,或者决定未来扩展对其他配置文件的支持,读者可以按照合适的设计模式先设计好框架。

from pathlib import Path
import tomllib

def load_toml(name: str) -> dict:
    """
    加载指定名称的TOML配置文件并解析为字典。

    参数:
    name (str): 配置文件的名称(不包含.toml后缀),函数会自动拼接后缀并在项目根目录的config子目录中查找。

    返回:
    dict: 解析后的TOML配置数据,以字典形式返回,键值对应配置文件中的内容。

    异常处理:
    - FileNotFoundError: 当指定的配置文件在config目录中不存在时引发,包含具体的文件路径信息。
    - tomllib.TOMLDecodeError: 当配置文件存在但TOML格式错误(如语法错误、结构异常)导致无法解析时引发,包含错误详情。
    - RuntimeError: 当发生其他未知错误(如文件读取权限问题、意外异常)导致加载失败时引发,包含错误描述。
    """
    root_directory = Path(__file__).resolve().parent.parent
    name = name + ".toml"
    
    # 定义文件路径
    config_path = root_directory / "config" / name
    data_path = root_directory / "data" / name
    # 判断文件是否存在,如果都不存在,赋值为None
    toml_file_path = config_path if config_path.is_file() else (data_path if data_path.is_file() else None)
    
    try:
        with open(str(toml_file_path), "rb") as f:
            data = tomllib.load(f)

        return data
    except FileNotFoundError:
        raise FileNotFoundError(f"配置文件不存在: {toml_file_path}")
    except tomllib.TOMLDecodeError as e:
        raise ValueError(f"配置文件格式错误: {str(e)}")
    except Exception as e:
        raise RuntimeError(f"加载配置失败: {str(e)}")
        

# 用于测试
def main() -> None:
    print(load_toml("data"))


if __name__ == '__main__':
    main()

设计框架

解决获取配置信息的问题后,我们需要思考如何在项目中支持不同的数据库。我选择的设计模式是工厂策略模式,核心是基于连接池。

from abc import ABC, abstractmethod

from common.get_data import load_toml


# ====================== 1. 抽象策略接口 ======================
class DatabaseConnectionPool(ABC):
    def __init__(self, database_configuration_name: str) -> None:
        self._database_configuration_name = database_configuration_name
        self._con_data = self._get_data_from_config()
        self._create_pool()

    @property
    @abstractmethod
    def database_con_info(self) -> list:
        """
        定义数据库的必要连接信息。如:["host", "port", "user", "password", "database", "charset"]
        """
        pass

    def _get_data_from_config(self) -> dict:
        """
        从TOML配置文件加载并校验数据库连接配置,分离基础连接配置与连接池配置。

        该函数负责读取指定的TOML配置文件,定位目标数据库配置项,校验配置的完整性(确保包含所有必要参数),
        并将连接池配置(可选)与基础数据库连接配置分离,最终返回结构化的配置数据,为数据库连接池初始化提供支持。

        参数:
            无显式入参,依赖类实例属性:
                - self._database_configuration_name: 目标数据库配置项的名称(需在配置文件中存在)
                - self.database_con_info: 基础数据库连接的必要参数列表(如host、user、password、db等)

        返回:
            dict: 字典,包含基础数据库连接配置与连接池配置(可选)

        异常处理:
            - RuntimeError: 加载TOML配置文件失败时触发(如文件不存在、格式错误、权限不足等)
            - ValueError: 配置校验失败时触发,包含两种场景:
                - 配置文件中未找到与self._database_configuration_name匹配的配置项
                - 目标配置项中缺少self.database_con_info定义的必要参数
        """
        try:
            database_con_data = load_toml("database")
        except Exception as e:
            raise RuntimeError(f"加载数据库配置文件失败: {str(e)}") from e

        if self._database_configuration_name not in database_con_data:
            raise ValueError(
                f"配置文件中未找到名称为 '{self._database_configuration_name}' 的数据库配置项,"
                f"可用配置项:{list(database_con_data.keys())}"
            )

        con_data = database_con_data[self._database_configuration_name]

        for key in self.database_con_info:
            if key not in con_data:
                raise ValueError(
                    f"数据库配置项 '{self._database_configuration_name}' 缺少必要参数: '{key}',"
                    f"必需参数列表:{list(self.database_con_info)}"
                )

        return con_data

    @abstractmethod
    def _create_pool(self):
        """创建数据库连接池"""
        pass

    @abstractmethod
    def close_pool(self):
        """关闭数据库连接池"""
        pass
    

# ====================== 2. 具体策略类,待实现======================




# ====================== 3. 策略工厂类======================
class DatabaseConnectionPoolFactory:
    """数据库连接池工厂:封装连接池对象的创建"""
    # 策略映射表:键为数据库类型,值为对应的连接池类
    _STRATEGY_MAP = {
        "具体数据库类型1": 具体数据库连接池类1,
        "具体数据库类型2": 具体数据库连接池类2
    }

    # ========== 重载签名1:匹配 具体数据库类型1 类型,返回 具体数据库连接池类1 ==========
    @staticmethod
    @overload
    def get_strategy(
        database_type: Literal["具体数据库类型1"],
        database_configuration_name: str
    ) -> 具体数据库连接池类1:
        ...

    # ========== 重载签名2:匹配 具体数据库类型2 类型,返回 具体数据库连接池类2 ==========
    @staticmethod
    @overload
    def get_strategy(
        database_type: Literal["具体数据库类型2"],
        database_configuration_name: str
    ) -> 具体数据库连接池类2:
        ...

    # ========== 实现签名:兼容所有情况 ==========
    @staticmethod
    def get_strategy(
        database_type: Literal["具体数据库类型1", "具体数据库类型2"] | str,  # 覆盖有效类型+任意字符串
        database_configuration_name: str
    ) -> Union[具体数据库连接池类1,具体数据库连接池类2]:
        """根据数据库类型获取对应的数据库连接池实例
        :param database_type: 数据库类型(比如 mysql/redis)
        :param database_configuration_name: database.toml 中的配置项
        :return: 数据库连接池实例
        :raises ValueError: 不支持的数据库类型
        """
        strategy_class = DatabaseConnectionPoolFactory._STRATEGY_MAP.get(database_type)
        if not strategy_class:
            raise ValueError(f"不支持的数据库类型:{database_type}")
        # 创建并返回连接池实例(传入配置名称)
        return strategy_class(database_configuration_name)

框架规定,数据库连接信息只能定义在 config/ 下的 database.toml 中。

关于 ABC 的相关内容可以查看我的另一篇文章抽象类

如果读者不了解什么是工厂策略模式,可以浏览器搜索设计模式进而了解。

@overloa 也是一个有意思的装饰器,读者可以搜索了解,后续我会写文章介绍它。

支持 mysql


我使用的是 pymysql 库,由于 pymysql 库不支持连接池,又使用了 dbutils 库。DBUtils 是一个为 Python 的 DB-API 2.0 接口 提供数据库连接池的第三方库,pymysql 就是由 DB-API 2.0 驱动。

安装 pymysql 库和 dbutils 库

pip install pymysql dbutils

实现逻辑

from abc import ABC, abstractmethod
from typing import Generator, Optional, Union, overload, Literal

import pymysql
from dbutils.pooled_db import PooledDB

from common.get_data import load_toml


class DatabaseConnectionPool(ABC):
   ...


class MysqlConnectionPool(DatabaseConnectionPool):
    """
    MySQL 数据库连接池封装类,基于 `dbutils.pooled_db.PooledDB` 实现。
    提供自动加载配置文件、创建连接池及异常处理功能,简化 MySQL 数据库操作流程。

    核心功能:
        - 自动加载 `database.toml` 中指定名称的 MySQL 配置, 创建 MySQL 连接池,支持自定义连接池参数(如最小/最大连接数等)
        - 提供 select_database、select_large_database 方法用于查询数据库,selectchange_database 方法用于变更(插入、修改、删除)数据库,close_pool 方法用于关闭连接池

    参数:
        - database_configuration_name (str): 数据库连接配置名称(对应 `database.toml` 中的配置项)。

    异常处理:
        - ConnectionError: 触发场景为 MySQL 连接失败(如主机不可达、端口错误、账号密码错误等),
        异常信息包含底层 `pymysql.MySQLError` 详情和配置名称;
        - RuntimeError: 触发场景为连接池初始化过程中的其他异常(如配置解析失败、`PooledDB` 调用异常等),
        异常信息包含底层错误详情和配置名称;
        - 所有异常均通过 `raise ... from e` 保留原始异常链路,便于排查根因。

    最佳实践建议:
        - 配置文件准备:
            - 确保 `database.toml` 放置在项目根目录的 `config` 子目录下(符合 `load_toml` 函数的路径约定);
            - 配置文件中,每个数据库配置项需以表的形式定义(如 `[mysql]` 或 `[mysql_dev]`),且包含 MySQL 所需的所有必要参数(如 "host"、"port"、"user"、"password"、"database" 等);
            - 避免在配置文件中遗漏必要参数,否则会触发参数缺失异常。请对照 database.py 模块中 `self.database_con_info` 定义进行检查。
            - 连接池参数如 `mincached`、`maxcached`、`maxconnections` 等,可选配置如:

            ```toml
            [mysql]
            host = "localhost"
            port = 3306
            user = "root"
            password = "password"
            database = "test_db"
            charset = "utf8mb4"
            mincached = 2
            maxcached = 5
            maxconnections = 10
            blocking = true
            ```

        - 实例化建议:
            - 入口模块中根据需要,实例化创建连接池对象:

            ```python
            mysql_connection = MySQLDataBase()
            mysql_dev_connection = MySQLDataBase("mysql_dev")
            ```
            - 入口模块中最后关闭连接池对象:

            ```python
            mysql_connection.close_pool()
            mysql_dev_connection.close_pool()
            ```

        - 连接池参数调优:
            - mincached:根据业务最小并发量设置,避免频繁创建新连接;
            - maxconnections:根据数据库服务器最大连接数和业务峰值并发量设置,避免超过数据库限制;
            - blocking:建议设为 true(默认),避免峰值时因无可用连接直接报错,而是阻塞等待直到有连接释放;
            - 若业务并发量较大,可适当增大 maxcached 和 maxconnections,但需注意数据库服务器的资源限制。
    """

    @property
    def database_con_info(self) -> list:
        """定义Mysql数据库的必要连接信息"""
        return ["host", "user", "password", "database", "charset"]

    def _create_pool(self) -> None:
        try:
            self._pool = PooledDB(creator=pymysql, **self._con_data)
        except pymysql.MySQLError as e:
            raise ConnectionError(
                f"MySQL连接失败: {str(e)},配置名称: {self._database_configuration_name}"
            ) from e
        except Exception as e:
            raise RuntimeError(
                f"MySQL连接池初始化失败: {str(e)},配置名称: {self._database_configuration_name}"
            ) from e

    def close_pool(self) -> None:
        """关闭连接池释放资源"""
        if hasattr(self, "_pool") and self._pool:
            self._pool.close()

    def __del__(self) -> None:
        """
        析构函数:在对象被销毁时自动关闭连接池
        """
        self.close_pool()

    def select_database(
        self,
        sql: str,
        params: Optional[tuple] = None,
    ) -> list[dict]:
        """
        通过数据库连接池执行SELECT查询,并返回字典格式的结果集。

        该函数从预先初始化的连接池中获取一个连接,使用参数化查询执行SQL语句,
        以防止SQL注入风险。查询结果会被转换为字典列表,其中每个字典的键为数据库表的列名,
        值为对应的记录值。查询结束后,连接和游标会被`with`语句自动管理和释放。

        参数:
            - sql (str): 待执行的SELECT查询语句。为防止SQL注入,必须使用参数化查询,
                即使用`%s`作为占位符。
                - 示例: "SELECT id, name FROM users WHERE age > %s AND status = %s"

            - params (Optional[tuple]): 可选参数,用于替换SQL语句中的`%s`占位符。
                参数的数量和类型必须与SQL语句中的占位符完全匹配。
                - 当SQL语句中无占位符时,可传入`None`或不传入此参数。
                - 示例: 若SQL为"SELECT * FROM users WHERE id = %s", 则params可为`(101,)`。
                - 注意: 即使只有一个参数,也必须使用元组形式,如`(value,)`,不能是`(value)`。

        返回:
            - List[Dict]: 查询结果的列表。列表中的每个元素都是一个字典,
                - 字典的键是查询结果集中的列名(如'id', 'name'),值是对应的数据。
                - 如果查询没有返回任何结果,将返回一个空列表`[]`。

        异常处理:
            - pymysql.MySQLError: 当SQL执行失败时触发,例如SQL语法错误、表不存在、
                权限不足等。异常信息会包含具体的错误描述、执行的SQL语句以及传入的参数,
                便于快速定位问题。
            - Exception: 捕获其他所有未预期的异常,确保程序稳定性并提供清晰的错误信息。

        最佳实践建议:
            - 参数化查询: 始终使用`params`参数传递动态值,严禁直接将用户输入或变量拼接进SQL字符串。
                - 正确示例: sql="SELECT * FROM users WHERE name = %s", params=(username,)
                - 错误示例: sql=f"SELECT * FROM users WHERE name = '{username}'" (存在注入风险)
            - 参数类型: `params`必须是一个元组。即使只有一个参数,也需要写成元组形式`(value,)`。
            - SQL格式化: 为了提高可读性,复杂的SQL语句建议使用多行字符串或外部文件存储。
            - 结果处理: 函数返回的是完整的结果集。如果预期结果集非常大,一次性加载可能会消耗大量内存。
                这种情况下,建议使用`cursor.fetchone()`或`cursor.fetchmany(size)`进行分批处理。
            - 异常捕获: 在生产环境中调用此函数时,强烈建议使用 `try...except` 块来捕获并处理上述可能的异常,
                例如,向用户返回更友好的错误提示或记录日志。
            - 连接池配置: 确保数据库连接池的配置(如最大连接数、空闲连接超时等)适合你的应用场景,
                以避免连接耗尽或资源浪费。
        """
        try:
            with self._pool.connection() as conn:
                with conn.cursor(cursor=pymysql.cursors.DictCursor) as cursor:
                    cursor.execute(sql, params)
                    results = cursor.fetchall()
                    return results

        except pymysql.MySQLError as e:
            raise pymysql.MySQLError(
                f"数据库查询失败: {e}, 执行的SQL: {sql}, 使用的参数: {params}"
            ) from e
        except Exception as e:
            raise Exception(f"执行查询时发生未知错误: {e}") from e

    def select_large_database(
        self, sql: str, params: Optional[tuple] = None, batch_size: int = 1000
    ) -> Generator[list[dict], None, None]:
        """
        分批执行SELECT查询,适用于处理大型结果集。

        通过生成器(Generator)逐批返回查询结果,避免一次性将所有数据加载到内存。

        参数:
            - sql (str): 待执行的SELECT查询语句。
            - params (Optional[tuple[Any, ...]]): 用于替换SQL占位符的参数。
            - batch_size (int): 每批获取的数据行数。默认值为1000,可根据内存和性能需求调整。

        返回:
            - Generator[List[Dict[str, Any]], None, None]: 一个生成器,每次迭代产出一批数据(字典列表)。

        异常处理:
            - pymysql.MySQLError: 数据库查询相关错误。
            - Exception: 其他未知错误。

        最佳实践:
            ```python
            db_pool = DatabasePool()
            sql = "SELECT id, name FROM very_large_table WHERE created_at > %s"
            params = ("2023-01-01",)
            for batch in db_pool.select_large_dataset(sql, params, batch_size=500):
                process_batch(batch) # 你自己的批量处理逻辑
            ```
        """
        conn = None
        cursor = None
        try:
            conn = self._pool.connection()
            with conn.cursor(cursor=pymysql.cursors.DictCursor) as cursor:
                cursor.execute(sql, params)

                while True:
                    batch = cursor.fetchmany(batch_size)
                    if not batch:
                        break
                    yield batch

        except pymysql.MySQLError as e:
            raise pymysql.MySQLError(
                f"分批查询数据库失败: {e}, SQL: {sql}, 参数: {params}"
            ) from e
        except Exception as e:
            raise Exception(f"执行分批查询时发生未知错误: {e}") from e
        finally:
            if conn:
                conn.close()

    def change_database(
        self,
        sql: str,
        params: Union[None, tuple, list[tuple]] = None,
        batch_size: int = 1000,
    ) -> int:
        """
        执行数据库变更操作(INSERT, UPDATE, DELETE),并支持单条或批量处理。

        该函数从连接池获取连接,根据参数`params`的类型,选择执行单条SQL语句或分批执行多条SQL语句。
        批量处理时,会按照`batch_size`指定的大小拆分参数列表,以避免单次操作数据量过大导致的性能问题或数据库限制。
        操作成功后会自动提交事务(`commit`),若过程中发生任何错误,则会回滚事务(`rollback`)并抛出异常。

        参数:
            - sql (str): 待执行的SQL变更语句。
                - 单条操作时,使用`%s`作为参数占位符,例如: "INSERT INTO users (name, age) VALUES (%s, %s)"。
                - 批量操作时,SQL语句格式与单条一致,`executemany`会自动为列表中的每个元组执行一次。
            - params (Union[None, Tuple[Any, ...], List[Tuple[Any, ...]]]): 可选参数,用于替换SQL中的占位符。
                - `None`: 表示SQL语句中无参数,直接执行。
                - `Tuple[Any, ...]`: 单条操作的参数,例如: ("Alice", 30)。
                - `List[Tuple[Any, ...]]`: 批量操作的参数列表,每个元组对应一条SQL语句的参数,例如: [("Alice", 30), ("Bob", 25)]。
            - batch_size (int): 批量处理时的每批参数数量,默认值为1000。
                - 需根据数据库性能、网络带宽及SQL语句复杂度调整,避免过大或过小。

        返回:
            - int: 本次操作受影响的总行数。

        异常处理:
            - TypeError: 若`params`的类型不是None、元组或元组列表,抛出类型错误。
            - MySQLError: 数据库执行失败(如语法错误、约束冲突、权限不足等),回滚事务并抛出异常,包含错误信息、执行的SQL及参数。
            - Exception: 捕获其他未预期的异常,确保程序稳定性。

        最佳实践建议:
            - 参数化查询: 始终使用`params`传递动态值,严禁直接拼接SQL字符串,防止SQL注入。
            - 批量操作优化:
                - 批量插入/更新时,优先使用该函数的批量处理功能(`params`为列表),而非循环调用单条操作,减少数据库连接次数和网络开销。
                - 根据数据量调整`batch_size`:数据量小时可适当减小,数据量大时可增大(如5000-10000),但需避免超过数据库单次处理限制(如MySQL的`max_allowed_packet`)。
            - 事务管理: 该函数已内置事务提交/回滚逻辑,无需在外部手动管理事务。
            - 异常处理: 调用时建议使用`try-except`捕获异常,记录详细日志(如错误信息、SQL、参数),便于问题排查。
            - 数据校验: 批量处理前,建议校验`params`列表中每个元组的长度与SQL语句中占位符数量一致,避免因参数不匹配导致执行失败。
        """
        conn = None
        try:
            with self._pool.connection() as conn:
                with conn.cursor() as cursor:
                    affected = 0

                    if params is None:
                        affected = cursor.execute(sql)

                    elif isinstance(params, tuple):
                        affected = cursor.execute(sql, params)

                    elif isinstance(params, list):
                        if not params:
                            return 0
                        for i in range(0, len(params), batch_size):
                            batch_params = params[i : i + batch_size]
                            batch_affected = cursor.executemany(sql, batch_params)
                            affected += batch_affected

                    else:
                        raise TypeError(
                            f"参数params类型错误,需为None、元组或元组列表,当前类型: {type(params)}"
                        )

                    conn.commit()
                    return affected

        except pymysql.MySQLError as e:
            if conn and conn.open:
                try:
                    conn.rollback()
                except Exception as rollback_e:
                    raise Exception(f"事务回滚失败: {str(rollback_e)}") from rollback_e
            raise pymysql.MySQLError(
                f"数据库变更失败: {str(e)},执行SQL: {sql},参数: {params}"
            ) from e

        except TypeError as e:
            raise e

        except Exception as e:
            raise Exception(f"执行数据库变更时发生未知错误: {str(e)}") from e
        
class DatabaseConnectionPoolFactory:
    """数据库连接池工厂:封装连接池对象的创建"""
    # 策略映射表:键为数据库类型,值为对应的连接池类
    _STRATEGY_MAP = {
        "mysql": MysqlConnectionPool
    }

    @staticmethod
    @overload
    def get_strategy(
        database_type: Literal["mysql"],
        database_configuration_name: str
    ) -> MysqlConnectionPool:
        ...

    @staticmethod
    def get_strategy(
        database_type: Literal["mysql", "redis"] | str,
        database_configuration_name: str
    ) -> MysqlConnectionPool:
        """根据数据库类型获取对应的数据库连接池实例
        :param database_type: 数据库类型(比如 mysql)
        :param database_configuration_name: database.toml 中的配置项
        :return: 数据库连接池实例
        :raises ValueError: 不支持的数据库类型
        """
        strategy_class = DatabaseConnectionPoolFactory._STRATEGY_MAP.get(database_type)
        if not strategy_class:
            raise ValueError(f"不支持的数据库类型:{database_type}")
        # 创建并返回连接池实例(传入配置名称)
        return strategy_class(database_configuration_name)


def main() -> None:
    
    """测试"""
    mysql_connection = DatabaseConnectionPoolFactory.get_strategy(database_type="mysql", database_configuration_name="mysql")
    sql = "SELECT * FROM user "
    print(mysql_connection.select_database(sql))
    mysql_connection.close_pool()


if __name__ == "__main__":
    main()

上述代码中,通过 dbutils 库为 pymysql 增加了连接池管理,并且封装了 select_database、select_large_database 方法用于查询数据库,selectchange_database 方法用于变更(插入、修改、删除)数据库。

关于 pymysql 的相关内容可以查看我的文章:远程访问 Mysql 数据库,并对数据进行简单处理

支持 redis


操作 Redis 数据库,Python 官方推荐客户端是 redis-py(确实推荐),其不仅完全支持 Redis 的所有核心功能,而且自带连接池管理。

安装 redis-py 库

pip install redis

实现逻辑

from abc import ABC, abstractmethod
from typing import Generator, Optional, Union, overload, Literal

import pymysql
import redis
from dbutils.pooled_db import PooledDB

from common.get_data import load_toml


class DatabaseConnectionPool(ABC):
    ...


class MysqlConnectionPool(DatabaseConnectionPool):
    ...


class RedisConnectionPool(DatabaseConnectionPool):
    """
    Redis 数据库连接池封装类,基于 `redis.ConnectionPool` 实现
    提供自动加载配置文件、创建连接池及异常处理功能,简化 Redis 数据库操作流程。
    核心功能:
        - 自动加载 `database.toml` 中指定名称的 Redis 配置, 创建 Redis 连接池。
        - 提供 close_pool 方法用于关闭连接池。
        - 可使用 redis 的所有功能,通过传入连接池实例化 redis.Redis 对象进行操作。
    参数:
        - database_configuration_name (str): 数据库连接信息配置名称(对应 `database.toml` 中的配置项)。
    异常处理:
        - RuntimeError: 触发场景为连接池初始化过程中的异常(如配置解析失败、`ConnectionPool` 调用异常等),异常信息包含底层错误详情和配置名称;
        - 所有异常均通过 `raise ... from e` 保留原始异常链路,便于排查根因。
    最佳实践建议:
        - 配置文件准备:
            - 确保 `database.toml` 放置在项目根目录的 `config` 子目录下(符合 `load_toml` 函数的路径约定);
            - 配置文件中,每个数据库配置项需以表的形式定义(如 `[redis]` 或 `[redis_dev]`),且包含 Redis 所需的所有必要参数(如 "host"、"port"、"password"、"db" 等);
            - 避免在配置文件中遗漏必要参数,否则会触发参数缺失异常。请对照 database.py 模块中 `self.database_con_info` 定义进行检查。
            - 示例配置格式:
            ```toml
            [redis]
            host = "localhost"
            port = 6379
            password = "yourpassword"
            db = 0
            max_connections = 4
            ```
        - 实例化建议:
            - 入口模块中根据需要,实例化创建连接池对象:
            ```python
            redis_connection = RedisConnectionPool("redis")
            redis_dev_connection = RedisConnectionPool("redis_dev")
            ```
            - 入口模块中最后关闭连接池对象:
            ```python
            redis_connection.close_pool()
            redis_dev_connection.close_pool()
            ```
    """

    @property
    def database_con_info(self) -> list:
        """定义Redis数据库的必要连接信息"""
        return ["host", "port", "password", "db"]

    def _create_pool(self):
        try:
            self._pool = redis.ConnectionPool(**self._con_data)
        except Exception as e:
            raise RuntimeError(f"Redis连接池初始化失败: {str(e)}") from e

    def close_pool(self) -> None:
        """关闭连接池释放资源"""
        if hasattr(self, "_pool") and self._pool:
            self._pool.disconnect()

    def __del__(self):
        """析构函数,自动关闭连接池"""
        self.close_pool()
        
        
class DatabaseConnectionPoolFactory:
    """数据库连接池工厂:封装连接池对象的创建"""
    # 策略映射表:键为数据库类型,值为对应的连接池类
    _STRATEGY_MAP = {
        "mysql": MysqlConnectionPool,
        "redis": RedisConnectionPool
    }

    @staticmethod
    @overload
    def get_strategy(
        database_type: Literal["mysql"],
        database_configuration_name: str
    ) -> MysqlConnectionPool:
        ...

    @staticmethod
    @overload
    def get_strategy(
        database_type: Literal["redis"],
        database_configuration_name: str
    ) -> RedisConnectionPool:
        ...

    @staticmethod
    def get_strategy(
        database_type: Literal["mysql", "redis"] | str,  # 覆盖有效类型+任意字符串
        database_configuration_name: str
    ) -> Union[MysqlConnectionPool,RedisConnectionPool]:
        """根据数据库类型获取对应的数据库连接池实例
        :param database_type: 数据库类型(比如 mysql/redis)
        :param database_configuration_name: database.toml 中的配置项
        :return: 数据库连接池实例
        :raises ValueError: 不支持的数据库类型
        """
        strategy_class = DatabaseConnectionPoolFactory._STRATEGY_MAP.get(database_type)
        if not strategy_class:
            raise ValueError(f"不支持的数据库类型:{database_type}")
        # 创建并返回连接池实例(传入配置名称)
        return strategy_class(database_configuration_name)


def main() -> None:
    
    """测试"""
    mysql_connection = DatabaseConnectionPoolFactory.get_strategy(database_type="mysql", database_configuration_name="mysql")
    sql = "SELECT * FROM user "
    print(mysql_connection.select_database(sql))
    mysql_connection.close_pool()

    # redis_connection = DatabaseConnectionPoolFactory.get_strategy(database_type="redis", database_configuration_name="redis")
    # r = redis.Redis(connection_pool=redis_connection._pool)
    # aaa = r.hget("child_hash", "192.168.9.118")
    # print(aaa)
    # redis_connection.close_pool()


if __name__ == "__main__":
    main()

由于 redis-py 库自带连接池管理,且我本身没有特别的需求,封装起来特别简单。

项目中使用


启动文件中,运行测试用例之前实例化需要的数据库连接池,全部结束后再关闭数据库连接池。

比如我现在的项目需要操作 mysql 和 redis,可以这样写:

import logging
import sys
import argparse

from common.database import DatabaseConnectionPoolFactory
import pytest
from pytest import ExitCode  

# 自定义日志格式
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(message)s'
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
logger = logging.getLogger(__name__)

# 运行测试
def run_tests(parser: argparse.ArgumentParser) -> int:
    """运行测试"""
    # 解析自定义参数和 pytest 参数
    customize_args, pytest_args = parser.parse_known_args()
    
    # 如果用户请求帮助,则打印帮助信息并退出
    if customize_args.help:
        parser.print_help()
        print("\n=== pytest 原生参数 help 信息 ===")
        pytest.main(["-h"])
        return 0  

    # 开始运行测试
    try:
        logger.info(f"开始运行测试,,pytest参数: {pytest_args}")
        exit_code = pytest.main(pytest_args)

        exit_code_value = exit_code.value if isinstance(exit_code, ExitCode) else exit_code
        exit_messages = {
            0: "✅ 全部测试用例通过",
            1: "⚠️ 部分测试用例未通过",
            2: "❌ 测试过程中有中断或其他非正常终止",
            3: "❌ 内部错误",
            4: "❌ 命令行参数错误",
            5: "❌ 没有收集到任何测试用例"
        }

        logger.info(exit_messages.get(exit_code_value, f"❓ 未知的退出码: {exit_code_value}"))
        return exit_code_value

    except Exception:
        logger.exception("运行测试时发生致命错误:")
        return 1

# 构建参数解析器,传递 pytest 参数,后续根据需求创建自定义参数
def parse_arguments() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="pytest 启动脚本(支持自定义参数 + pytest 参数)",
        add_help=False  # 隐藏默认帮助信息,避免与 pytest 的 -h 冲突
    )

    parser.add_argument(
        "-h", "--help",
        action="store_true",
        help="显示帮助信息(包含自定义参数和pytest原生参数)"
    )

    return parser

# 主函数
def main():
    # 创建数据库连接池
    mysql_connection = DatabaseConnectionPoolFactory.get_strategy(
        database_type="mysql", 
        database_configuration_name="mysql")
    logger.info("已创建MySQL数据库连接池")
    redis_connection = DatabaseConnectionPoolFactory.get_strategy(
        database_type="redis", 
        database_configuration_name="redis")
    logger.info("已创建Redis数据库连接池")
    # 获取参数
    parser = parse_arguments()
    exit_code = run_tests(parser)
    # 关闭数据库连接池
    mysql_connection.close_pool()
    logger.info("已关闭MySQL数据库连接池")
    redis_connection.close_pool()
    logger.info("已关闭Redis数据库连接池")
    # 退出
    sys.exit(exit_code)
    logger.info("已退出")

if __name__ == "__main__":
    main()

本文中只支持了 mysql 和 redis,是因为我写作本文时,只需要操作这两种数据库,未来我可能会扩展更多的数据库支持。但最重要的,读者应该根据自己的需求,参考我的想法,自己去实现需求。


此时的目录结构如下:

接口自动化测试框架/
├── business/                      # 业务逻辑层(接口封装层,解耦用例与接口细节)
│   └── __init__.py
├── common/                        # 【核心】公共工具模块(复用性代码,核心层)
│   ├── __init__.py                # 标记为Python包
│   ├── get_data.py                # 数据获取工具,内有 load_toml 函数
│   └── database.py                # 数据库封装工具
├── cases/                         # 自动化用例目录(仅调用业务层,不写具体逻辑)
│   └── __init__.py
├── config/                        # 配置文件目录(与代码解耦,环境切换核心)
│   └── __init__.py
├── data/                          # 测试数据目录(数据驱动,与用例分离)
├── docs/                          # 项目文档目录(团队协作必备)
│   ├── api_docs.md                # 接口文档(地址、参数、请求方式、响应示例)
│   └── usage_guide.md             # 使用指南(环境搭建、运行命令、问题排查)
├── pytest.toml                    # pytest 配置文件
├── run.py                         # 项目启动文件(入口)
├── requirements.txt               # 项目依赖包(指定版本,避免环境问题)
└── README.md                      # 项目说明(必选,快速上手:环境、运行、目录说明)

THEEND



© 转载需要保留原始链接,未经明确许可,禁止商业使用。CC BY-NC-ND 4.0