
超时装饰器
虽然集成的 pytest 和 allure ,本身功能已经非常丰富,但可能还是不满足需求。比如 Pytest 系列中提到超时标记插件在 Windows 平台上存在问题,只能自己实现个装饰器。
- 在 common/ 目录下,新建 timeout.py
import concurrent.futures
import functools
from typing import Any, Callable
def timeout(seconds: int = 10) -> Callable:
"""
装饰器:为函数添加超时限制,若执行时间超过指定秒数则抛出 AssertionError。
:param seconds: 超时时间(单位:秒),默认为 10。
:return: 被包装的函数。
"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(func, *args, **kwargs)
try:
return future.result(timeout=seconds)
except concurrent.futures.TimeoutError:
error_msg = f"测试用例 {func.__name__} 已执行超过 {seconds} 秒。"
raise AssertionError(error_msg) from None
return wrapper
return decorator
- 使用 @timeout 装饰器
from common.timeout import timeout
import time
@timeout(3)
def test_case1():
time.sleep(5)
assert True
@timeout(5)
def test_case2():
time.sleep(3)
assert True
- 运行结果
===================================================================== test session starts ======================================================================
configfile: pytest.ini
plugins: allure-pytest-2.15.0
collected 2 items
tests\test_test.py F. [100%]
=========================================================================== FAILURES ===========================================================================
__________________________________________________________________________ test_case1 __________________________________________________________________________
args = (), kwargs = {}, executor = <concurrent.futures.thread.ThreadPoolExecutor object at 0x0000016D32417CB0>
future = <Future at 0x16d32514050 state=finished returned NoneType>, error_msg = '测试用例 test_case1 已执行超过 3 秒。'
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(func, *args, **kwargs)
try:
return future.result(timeout=seconds)
except concurrent.futures.TimeoutError:
error_msg = f"测试用例 {func.__name__} 已执行超过 {seconds} 秒。"
> raise AssertionError(error_msg) from None
E AssertionError: 测试用例 test_case1 已执行超过 3 秒。
common\timeout.py:40: AssertionError
=================================================================== short test summary info ====================================================================
FAILED tests/test_test.py::test_case1 - AssertionError: 测试用例 test_case1 已执行超过 3 秒。
================================================================= 1 failed, 1 passed in 8.13s ==================================================================
从文件中读取参数化数据
不想要把参数化数据写在代码里,而是放进文件中,比如 csv 文件或者更多类型文件。这个需求我们来实现一下。
如果只想实现对 csv 文件的支持,单独实现个类或函数即可。如果现在或未来想扩展对其他类型文件的支持,我们可以使用工厂模式设计。
- 确定文件中数据结构,比如 csv 设计成这样:

Wps 和 Office 都可以编辑 csv 文件。
需要注意使用双引号包括字符串类型的数字,然后使用任意编辑器修缮数据(编码改为UTF8、双引号数量问题)。
mark 列确定本行数据应用于哪个 mark 标记。
- 确定读取 csv 文件数据后,需要转换的字典数据结构(或者其他):
{
"param": "name,age,phone",
"value": [("张三树",1,"17611111111"),("李四木",2,"17611111112"),.......]
"mark": {
"smoke": [0,1,2]
"user": [3,4,5]
}
}
- 在 common/ 目录下,新建 parameterized_data_from_files.py,根据前两步骤实现读取 csv 数据方法的逻辑:
from abc import ABC, abstractmethod
import csv
import re
from typing import Dict, Any, Tuple, List
from collections import OrderedDict
class ReadParameterizedData(ABC):
"""
抽象类,用于从文件中读取参数化数据
"""
@abstractmethod
def read_data(self, file_path: str) -> Dict:
"""
抽象方法,用于从文件中读取参数化数据
:return: 参数化数据列表
"""
class ReadParameterizedDataFromCsvFile(ReadParameterizedData):
"""
从CSV文件中读取参数化数据。
支持通过 'mark' 列对数据行进行标记分组。
数字:未加引号的转为 int,加了引号的保持为 str。
"""
@staticmethod
def _is_quoted(s: str) -> bool:
"""判断字符串是否被双引号包围"""
s = s.strip()
return len(s) >= 2 and s.startswith('"') and s.endswith('"')
@staticmethod
def _parse_value(s: str):
"""根据是否加引号决定是否转为 int"""
stripped = s.strip()
# 检查是否被引号包围
if ReadParameterizedDataFromCsvFile._is_quoted(stripped):
# 去掉外层引号,保留为字符串
return stripped[1:-1]
else:
# 尝试转换为整数
try:
return int(stripped)
except ValueError:
return stripped
def read_data(self, file_path: str) -> Dict[str, Any]:
"""
从CSV文件读取参数化数据,支持类型推断。
:param file_path: CSV文件路径
:return: 包含 param(参数名)、value(参数值列表)、mark(标记索引)的字典
"""
try:
with open(file_path, 'r', encoding='utf-8') as file:
lines = [line.rstrip('\n\r') for line in file if line.strip()]
if not lines:
return {"param": "", "value": []}
# 构建变量名参数
header_line = lines[0]
header_reader = csv.reader([header_line])
fieldnames = next(header_reader)
has_mark = 'mark' in fieldnames
param_keys = [key for key in fieldnames if key != 'mark']
param_str = ','.join(param_keys)
# 构建参数数据结构
parameterized_data = {
"param": param_str,
"value": []
}
if has_mark:
parameterized_data["mark"] = {}
# 解析数据行
csv_reader = csv.reader(lines[1:])
for line_num, row in enumerate(csv_reader, start=2):
raw_line = lines[line_num - 1]
pattern = r'(?:[^,"]|"(?:[^"]*)")+'
raw_fields = [match.strip() for match in re.findall(pattern, raw_line)]
if len(raw_fields) < len(fieldnames):
raw_fields += [''] * (len(fieldnames) - len(raw_fields))
# 构造 {fieldname: raw_string} 映射
raw_field_map = OrderedDict()
for i, key in enumerate(fieldnames):
if i < len(raw_fields):
raw_field_map[key] = raw_fields[i]
else:
raw_field_map[key] = ''
# 构建参数值列表
value_tuple = []
for key in param_keys:
raw_val = raw_field_map[key]
parsed = ReadParameterizedDataFromCsvFile._parse_value(raw_val)
value_tuple.append(parsed)
parameterized_data["value"].append(tuple(value_tuple))
# 处理 mark 标记
if has_mark and 'mark' in raw_field_map:
mark_raw = raw_field_map['mark']
mark_value = ReadParameterizedDataFromCsvFile._parse_value(mark_raw)
mark_value = mark_value.strip() if isinstance(mark_value, str) else str(mark_value).strip()
if mark_value:
if mark_value not in parameterized_data["mark"]:
parameterized_data["mark"][mark_value] = []
parameterized_data["mark"][mark_value].append(line_num - 2)
return parameterized_data
except FileNotFoundError:
raise FileNotFoundError(f"CSV文件未找到: {file_path}")
except PermissionError:
raise PermissionError(f"无权限读取文件: {file_path}")
except Exception as e:
raise RuntimeError(f"读取CSV文件时发生错误: {e}")
如果想实现对其他文件的支持,只需要添加其他文件的类,比如 ReadParameterizedDataFromJsonFile ,但是要确保输出的数据结构一致,因为后续还要再处理,数据才能应用于参数化。
- 实现 parametrized_data_from_file 方法进一步处理 mark 标记数据:
from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple, List
from collections import OrderedDict
import pytest
class ReadParameterizedData(ABC):
"""
抽象类,用于从文件中读取参数化数据
"""
@abstractmethod
def read_data(self, file_path: str) -> Dict:
"""
抽象方法,用于从文件中读取参数化数据
:return: 参数化数据列表
"""
def parametrized_data_from_file(self, file_path: str) -> Tuple[str, List[Tuple]]:
"""
从CSV文件读取参数化数据。
:param file_path: CSV文件路径
:return: 包含 param(参数名)、value(参数值列表, 支持 mark 标记)
"""
parametrized_data = self.read_data(file_path)
if not isinstance(parametrized_data, dict):
raise ValueError("Expected dictionary format from file.")
if 'param' not in parametrized_data:
raise ValueError("没有 param 参数")
if 'value' not in parametrized_data:
raise ValueError("没有 value 参数")
if 'mark' in parametrized_data:
for key, value in parametrized_data['mark'].items():
for i in value :
mark_obj = getattr(pytest.mark, key)
# 将标记添加到数据中
parametrized_data['value'][i] = pytest.param(*parametrized_data['value'][i], marks=mark_obj)
return parametrized_data['param'], parametrized_data['value']
- 实现 ParameterizedFileTypeSelector 类,支持只调用一个接口,自动选择不同类型文件类:
from pathlib import Path
from typing import Tuple, List
class ParameterizedFileTypeSelector:
"""
参数化数据源选择器
"""
@staticmethod
def _get_handler(file_path: str) -> ReadParameterizedData:
ext = Path(file_path).suffix.lower()
if ext == '.csv':
return ReadParameterizedDataFromCsvFile()
# 后续可以轻松扩展支持 .yaml / .json 等
raise ValueError(f"不支持的类型: {ext}")
def parameterized_data(self, file_path: str) -> Tuple[str, List[Tuple]]:
handler = self._get_handler(file_path)
return handler.parametrized_data_from_file(file_path)
- 完整代码如下:
from abc import ABC, abstractmethod
import csv
import re
from pathlib import Path
from typing import Dict, Any, Tuple, List
from collections import OrderedDict
import pytest
class ReadParameterizedData(ABC):
"""
抽象类,用于从文件中读取参数化数据
"""
@abstractmethod
def read_data(self, file_path: str) -> Dict:
"""
抽象方法,用于从文件中读取参数化数据
:return: 参数化数据列表
"""
def parametrized_data_from_file(self, file_path: str) -> Tuple[str, List[Tuple]]:
"""
从CSV文件读取参数化数据。
:param file_path: CSV文件路径
:return: 包含 param(参数名)、value(参数值列表, 支持 mark 标记)
"""
parametrized_data = self.read_data(file_path)
if not isinstance(parametrized_data, dict):
raise ValueError("Expected dictionary format from file.")
if 'param' not in parametrized_data:
raise ValueError("没有 param 参数")
if 'value' not in parametrized_data:
raise ValueError("没有 value 参数")
if 'mark' in parametrized_data:
for key, value in parametrized_data['mark'].items():
for i in value :
mark_obj = getattr(pytest.mark, key)
# 将标记添加到数据中
parametrized_data['value'][i] = pytest.param(*parametrized_data['value'][i], marks=mark_obj)
return parametrized_data['param'], parametrized_data['value']
class ReadParameterizedDataFromCsvFile(ReadParameterizedData):
"""
从CSV文件中读取参数化数据。
支持通过 'mark' 列对数据行进行标记分组。
数字:未加引号的转为 int,加了引号的保持为 str。
"""
@staticmethod
def _is_quoted(s: str) -> bool:
"""判断字符串是否被双引号包围"""
s = s.strip()
return len(s) >= 2 and s.startswith('"') and s.endswith('"')
@staticmethod
def _parse_value(s: str):
"""根据是否加引号决定是否转为 int"""
stripped = s.strip()
# 检查是否被引号包围
if ReadParameterizedDataFromCsvFile._is_quoted(stripped):
# 去掉外层引号,保留为字符串
return stripped[1:-1]
else:
# 尝试转换为整数
try:
return int(stripped)
except ValueError:
return stripped
def read_data(self, file_path: str) -> Dict[str, Any]:
"""
从CSV文件读取参数化数据,支持类型推断。
:param file_path: CSV文件路径
:return: 包含 param(参数名)、value(参数值列表)、mark(标记索引)的字典
"""
try:
with open(file_path, 'r', encoding='utf-8') as file:
lines = [line.rstrip('\n\r') for line in file if line.strip()]
if not lines:
return {"param": "", "value": []}
# 构建变量名参数
header_line = lines[0]
header_reader = csv.reader([header_line])
fieldnames = next(header_reader)
has_mark = 'mark' in fieldnames
param_keys = [key for key in fieldnames if key != 'mark']
param_str = ','.join(param_keys)
# 构建参数数据结构
parameterized_data = {
"param": param_str,
"value": []
}
if has_mark:
parameterized_data["mark"] = {}
# 解析数据行
csv_reader = csv.reader(lines[1:])
for line_num, row in enumerate(csv_reader, start=2):
raw_line = lines[line_num - 1]
pattern = r'(?:[^,"]|"(?:[^"]*)")+'
raw_fields = [match.strip() for match in re.findall(pattern, raw_line)]
if len(raw_fields) < len(fieldnames):
raw_fields += [''] * (len(fieldnames) - len(raw_fields))
# 构造 {fieldname: raw_string} 映射
raw_field_map = OrderedDict()
for i, key in enumerate(fieldnames):
if i < len(raw_fields):
raw_field_map[key] = raw_fields[i]
else:
raw_field_map[key] = ''
# 构建参数值列表
value_tuple = []
for key in param_keys:
raw_val = raw_field_map[key]
parsed = ReadParameterizedDataFromCsvFile._parse_value(raw_val)
value_tuple.append(parsed)
parameterized_data["value"].append(tuple(value_tuple))
# 处理 mark 标记
if has_mark and 'mark' in raw_field_map:
mark_raw = raw_field_map['mark']
mark_value = ReadParameterizedDataFromCsvFile._parse_value(mark_raw)
mark_value = mark_value.strip() if isinstance(mark_value, str) else str(mark_value).strip()
if mark_value:
if mark_value not in parameterized_data["mark"]:
parameterized_data["mark"][mark_value] = []
parameterized_data["mark"][mark_value].append(line_num - 2)
return parameterized_data
except FileNotFoundError:
raise FileNotFoundError(f"CSV文件未找到: {file_path}")
except PermissionError:
raise PermissionError(f"无权限读取文件: {file_path}")
except Exception as e:
raise RuntimeError(f"读取CSV文件时发生错误: {e}")
class ParameterizedFileTypeSelector:
"""
参数化数据源选择器
"""
@staticmethod
def _get_handler(file_path: str) -> ReadParameterizedData:
ext = Path(file_path).suffix.lower()
if ext == '.csv':
return ReadParameterizedDataFromCsvFile()
# 后续可以轻松扩展支持 .yaml / .json 等
raise ValueError(f"不支持的类型: {ext}")
def parameterized_data(self, file_path: str) -> Tuple[str, List[Tuple]]:
handler = self._get_handler(file_path)
return handler.parametrized_data_from_file(file_path)
- 使用时,测试用例的参数必须与 csv 文件中参数一致(不包括mark),顺序也要一致(示例:name, age, phone):
import pytest
from common.parameterized_data_from_files import ParameterizedFileTypeSelector
parameterized_data_from_file = ParameterizedFileTypeSelector()
parametrized_param, parametrized_data = parameterized_data_from_file.parameterized_data(
file_path=r"D:\Projects\Python\ApiAutomation\parameterized_files\test.csv")
@pytest.mark.parametrize(parametrized_param, parametrized_data)
def test_case1(name, age, phone):
print(name, age, phone)
执行 python .\app.py -m user :
===================================================================== test session starts ============================================================================================================ test session starts ======================================================================
configfile: pytest.ini
plugins: allure-pytest-2.15.0
collected 5 items / 3 deselected / 2 selected
tests\test_test.py FF [100%]
=========================================================================== FAILURES ===========================================================================
__________________________________________________________ test_case1[\u674e\u56db-25-17611111112_0] ___________________________________________________________
name = '李四', age = 25, phone = '17611111112'
@pytest.mark.parametrize(parametrized_param, parametrized_data)
def test_case1(name, age, phone):
> assert False , (name, age, phone)
E AssertionError: ('李四', 25, '17611111112')
E assert False
tests\test_test.py:30: AssertionError
__________________________________________________________ test_case1[\u674e\u56db-25-17611111112_1] ___________________________________________________________
name = '李四', age = 25, phone = '17611111112'
@pytest.mark.parametrize(parametrized_param, parametrized_data)
def test_case1(name, age, phone):
> assert False , (name, age, phone)
E AssertionError: ('李四', 25, '17611111112')
E assert False
tests\test_test.py:30: AssertionError
=================================================================== short test summary info ====================================================================
FAILED tests/test_test.py::test_case1[\u674e\u56db-25-17611111112_0] - AssertionError: ('李四', 25, '17611111112')
FAILED tests/test_test.py::test_case1[\u674e\u56db-25-17611111112_1] - AssertionError: ('李四', 25, '17611111112')
=============================================================== 2 failed, 3 deselected in 0.09s ================================================================
执行 python .\app.py:
===================================================================== test session starts ======================================================================
configfile: pytest.ini
plugins: allure-pytest-2.15.0
collected 5 items
tests\test_test.py FFFFF [100%]
=========================================================================== FAILURES ===========================================================================
__________________________________________________________ test_case1[\u5f20\u4e09-24-17611111111_0] ___________________________________________________________
name = '张三', age = 24, phone = '17611111111'
@pytest.mark.parametrize(parametrized_param, parametrized_data)
def test_case1(name, age, phone):
> assert False , (name, age, phone)
E AssertionError: ('张三', 24, '17611111111')
E assert False
tests\test_test.py:30: AssertionError
__________________________________________________________ test_case1[\u674e\u56db-25-17611111112_0] ___________________________________________________________
name = '李四', age = 25, phone = '17611111112'
@pytest.mark.parametrize(parametrized_param, parametrized_data)
def test_case1(name, age, phone):
> assert False , (name, age, phone)
E AssertionError: ('李四', 25, '17611111112')
E assert False
tests\test_test.py:30: AssertionError
___________________________________________________________ test_case1[\u738b\u4e94-26-17611111113] ____________________________________________________________
name = '王五', age = 26, phone = '17611111113'
@pytest.mark.parametrize(parametrized_param, parametrized_data)
def test_case1(name, age, phone):
> assert False , (name, age, phone)
E AssertionError: ('王五', 26, '17611111113')
E assert False
tests\test_test.py:30: AssertionError
__________________________________________________________ test_case1[\u674e\u56db-25-17611111112_1] ___________________________________________________________
name = '李四', age = 25, phone = '17611111112'
@pytest.mark.parametrize(parametrized_param, parametrized_data)
def test_case1(name, age, phone):
> assert False , (name, age, phone)
E AssertionError: ('李四', 25, '17611111112')
E assert False
tests\test_test.py:30: AssertionError
__________________________________________________________ test_case1[\u5f20\u4e09-24-17611111111_1] ___________________________________________________________
name = '张三', age = 24, phone = '17611111111'
@pytest.mark.parametrize(parametrized_param, parametrized_data)
def test_case1(name, age, phone):
> assert False , (name, age, phone)
E AssertionError: ('张三', 24, '17611111111')
E assert False
tests\test_test.py:30: AssertionError
=================================================================== short test summary info ====================================================================
FAILED tests/test_test.py::test_case1[\u5f20\u4e09-24-17611111111_0] - AssertionError: ('张三', 24, '17611111111')
FAILED tests/test_test.py::test_case1[\u674e\u56db-25-17611111112_0] - AssertionError: ('李四', 25, '17611111112')
FAILED tests/test_test.py::test_case1[\u738b\u4e94-26-17611111113] - AssertionError: ('王五', 26, '17611111113')
FAILED tests/test_test.py::test_case1[\u674e\u56db-25-17611111112_1] - AssertionError: ('李四', 25, '17611111112')
FAILED tests/test_test.py::test_case1[\u5f20\u4e09-24-17611111111_1] - AssertionError: ('张三', 24, '17611111111')
====================================================================== 5 failed in 0.11s =======================================================================
统一数据库接口
市面上有很多数据库,可能你的测试环境中就不只使用一种数据库。我们可以定义一个统一的接口,根据类型使用数据库。
- 在 configs/ 目录下,新建 database_config.py ,连接信息写在这里:
MYSQL = {
"host": "localhost",
"port": 3306,
"user": "root",
"password": "root",
"database": "test",
"charset": "utf8mb4"
}
- 在 common/ 目录下,新建 database.py ,使用工厂模式构建。以 MySQL 为例:
from abc import ABC, abstractmethod
from typing import Any, List, Dict, Tuple, Union
import pymysql
from pymysql import Connection, MySQLError
from configs.database_config import MYSQL
class DataBase(ABC):
"""
抽象类,用于定义数据库接口
"""
@abstractmethod
def _connect_to_the_database(self, **mysql_config) -> Connection:
pass
class MySQLDataBase(DataBase):
"""
MySQL 数据库类
"""
def _connect_to_the_database(self, **mysql_config) -> Connection:
"""
连接到MySQL数据库
:param mysql_config: 数据库配置,若未提供则使用默认配置
:return: 数据库连接对象
:raises: MySQLError 如果连接失败
"""
try:
if mysql_config:
connection = pymysql.connect(**mysql_config)
else:
connection = pymysql.connect(**MYSQL)
return connection
except MySQLError as e:
raise MySQLError(f"无法连接到数据库: {e}") from e
def select_database(
self,
sql: str,
data: Union[ None, Tuple[Any, ...]] = None,
**mysql_config
) -> Union[List[Dict[str, Any]], List[Tuple[Any, ...]]]:
"""
查询数据库
:param sql: SQL 查询语句
:param data: 查询参数
:param mysql_config: 数据库配置
:return: 查询结果,根据游标类型返回字典列表或元组列表
:raises: MySQLError 如果查询失败
"""
connection = None
try:
connection = self._connect_to_the_database(**mysql_config)
if data:
with connection.cursor() as cursor:
cursor.execute(sql, data)
result = cursor.fetchall()
return list(result)
else:
with connection.cursor() as cursor:
cursor.execute(sql)
result = cursor.fetchall()
return list(result)
except MySQLError as e:
raise MySQLError(f"数据库查询错误: {e}") from e
finally:
if connection:
connection.close()
def change_database(
self,
sql: str,
data: Union[ None, List[Tuple[Any, ...]], Tuple[Any, ...]] = None,
batch_size: int = 1000,
**mysql_config
) -> int:
"""
执行数据库变更操作(如插入、更新、删除)
:param sql: SQL 变更语句
:param data: 变更参数列表
:param batch_size: 批量操作大小
:param mysql_config: 数据库配置
:return: 成功处理的数据条数
:raises: MySQLError 如果变更失败
"""
connection = None
try:
connection = self._connect_to_the_database(**mysql_config)
with connection.cursor() as cursor:
if data:
if type(data) == tuple:
res = cursor.execute(sql, data)
if type(data) == list:
for i in range(0, len(data), batch_size):
batch = data[i:i + batch_size]
cursor.executemany(sql, batch)
res = len(data)
else:
res = cursor.execute(sql)
connection.commit()
return res
except MySQLError as e:
if connection:
connection.rollback()
raise MySQLError(f"数据库变更错误: {e}") from e
finally:
if connection:
connection.close()
class DatabaseTypeSelector:
def __new__(cls, db_type: str):
if db_type == "mysql":
return MySQLDataBase()
else:
raise ValueError(f"不支持的数据库类型: {db_type}")
if __name__ == "__main__":
db = DatabaseTypeSelector("mysql")
result = db.select_database(f"select * from user_v1 limit %s", (1,))
print(result)
示例中,默认 database_config.py 中的连接配置,也支持调用时传入连接配置。可以随时扩展其他类型数据库。
以上示例,仅供参考。读者可以按自己想法去实现。
THEEND

© 转载需要保留原始链接,未经明确许可,禁止商业使用。CC BY-NC-ND 4.0