#!/usr/bin/env python # -*- coding: utf-8 -*- """ 并发获取密集库位处理器测试 验证优化后的处理器在高并发场景下的安全性 """ import asyncio import pytest import uuid from typing import Dict, Any, List from unittest.mock import Mock, patch, AsyncMock from datetime import datetime # 假设的测试环境导入 from services.execution.handlers.storage_location import GetIdleCrowdedSiteBlockHandler from services.execution.task_context import TaskContext from data.models.operate_point_layer import OperatePointLayer from data.models.storage_area import StorageArea class TestConcurrentStorageLocation: """并发获取密集库位测试类""" @pytest.fixture def mock_session(self): """模拟数据库会话""" session = AsyncMock() session.execute = AsyncMock() session.commit = AsyncMock() session.rollback = AsyncMock() return session @pytest.fixture def mock_context(self): """模拟任务上下文""" context = Mock(spec=TaskContext) context.map_id = "test_map_001" context.task_record_id = "task_001" context.block_record_id = "block_001" context.set_variable = Mock() context.set_block_output = Mock() context.is_task_canceled_async = AsyncMock(return_value=False) return context @pytest.fixture def handler(self): """获取密集库位处理器实例""" return GetIdleCrowdedSiteBlockHandler() @pytest.fixture def mock_storage_area(self): """模拟库区数据""" area = Mock(spec=StorageArea) area.area_name = "test_area" area.select_logic = 0 # 默认按库位名称排序 return area @pytest.fixture def mock_storage_layers(self): """模拟库位数据""" layers = [] for i in range(1, 6): # 创建5个库位 layer = Mock(spec=OperatePointLayer) layer.layer_name = f"A00{i}" layer.area_name = "test_area" layer.scene_id = "test_map_001" layer.is_disabled = False layer.is_deleted = False layer.is_locked = False layer.is_occupied = False # 默认空闲 layer.goods_content = "" layer.goods_stored_at = None layer.last_access_at = datetime.now() layers.append(layer) return layers async def simulate_concurrent_requests(self, handler, num_requests: int, input_params: Dict[str, Any], mock_context): """模拟并发请求""" tasks = [] for i in range(num_requests): # 每个请求使用不同的任务ID context = Mock(spec=TaskContext) context.map_id = mock_context.map_id context.task_record_id = f"task_{i:03d}" context.block_record_id = f"block_{i:03d}" context.set_variable = Mock() context.set_block_output = Mock() context.is_task_canceled_async = AsyncMock(return_value=False) block = {"name": f"test_block_{i}"} # 创建异步任务 task = asyncio.create_task( handler.execute(block, input_params.copy(), context) ) tasks.append(task) # 等待所有任务完成 results = await asyncio.gather(*tasks, return_exceptions=True) return results @patch('services.execution.handlers.storage_location.get_async_session') async def test_concurrent_lock_requests(self, mock_get_session, handler, mock_context, mock_storage_area, mock_storage_layers): """测试并发锁定请求 - 应该只有一个成功""" # 设置模拟数据库会话 mock_session = AsyncMock() mock_get_session.return_value.__aenter__.return_value = mock_session # 模拟查询库区信息 area_result = AsyncMock() area_result.scalar_one_or_none.return_value = mock_storage_area # 模拟查询锁定库位(无锁定库位) locked_result = AsyncMock() locked_result.fetchall.return_value = [] # 模拟查询所有库位 all_layers_result = AsyncMock() all_layers_result.scalars.return_value.all.return_value = mock_storage_layers # 模拟查询候选库位 candidate_result = AsyncMock() candidate_result.scalars.return_value.all.return_value = mock_storage_layers # 设置查询结果的顺序 mock_session.execute.side_effect = [ area_result, # 查询库区 locked_result, # 查询锁定库位 all_layers_result, # 查询所有库位 candidate_result, # 查询候选库位 ] # 模拟原子更新操作 - 第一个请求成功,后续失败 update_results = [] successful_updates = 1 # 只允许一个成功 def mock_execute_update(*args, **kwargs): nonlocal successful_updates result = AsyncMock() if successful_updates > 0: result.rowcount = 1 # 更新成功 successful_updates -= 1 else: result.rowcount = 0 # 更新失败(被抢占) return result # 输入参数 input_params = { "groupName": '["test_area"]', "filled": "false", # 放货 "content": "test_goods", "lock": "true", # 需要锁定 "retry": "false" } # 启动多个并发请求 num_requests = 5 # 重置模拟对象状态 mock_session.reset_mock() # 为每个并发请求设置独立的模拟 concurrent_tasks = [] for i in range(num_requests): task_context = Mock(spec=TaskContext) task_context.map_id = "test_map_001" task_context.task_record_id = f"task_{i:03d}" task_context.block_record_id = f"block_{i:03d}" task_context.set_variable = Mock() task_context.set_block_output = Mock() task_context.is_task_canceled_async = AsyncMock(return_value=False) # 为每个请求创建独立的会话模拟 task_session = AsyncMock() # 为每个会话设置查询结果 task_area_result = AsyncMock() task_area_result.scalar_one_or_none.return_value = mock_storage_area task_locked_result = AsyncMock() task_locked_result.fetchall.return_value = [] task_all_layers_result = AsyncMock() task_all_layers_result.scalars.return_value.all.return_value = mock_storage_layers task_candidate_result = AsyncMock() task_candidate_result.scalars.return_value.all.return_value = mock_storage_layers task_session.execute.side_effect = [ task_area_result, task_locked_result, task_all_layers_result, task_candidate_result, mock_execute_update() # 原子更新操作 ] block = {"name": f"test_block_{i}"} # 创建独立的会话上下文管理器 async def get_task_session(): return task_session # 动态patch每个任务的session with patch('services.execution.handlers.storage_location.get_async_session') as mock_task_session: mock_task_session.return_value.__aenter__.return_value = task_session task = asyncio.create_task( handler.execute(block, input_params.copy(), task_context) ) concurrent_tasks.append(task) # 等待所有任务完成 results = await asyncio.gather(*concurrent_tasks, return_exceptions=True) # 验证结果 successful_results = [r for r in results if isinstance(r, dict) and r.get("success")] failed_results = [r for r in results if isinstance(r, dict) and not r.get("success")] # 断言:应该只有一个请求成功 assert len(successful_results) == 1, f"应该只有1个请求成功,实际成功了 {len(successful_results)} 个" assert len(failed_results) == num_requests - 1, f"应该有 {num_requests-1} 个请求失败,实际失败了 {len(failed_results)} 个" # 验证成功的请求返回了正确的库位 successful_result = successful_results[0] assert successful_result["data"]["siteId"] == "A001" assert successful_result["data"]["areaName"] == "test_area" assert successful_result["data"]["locked"] == True @patch('services.execution.handlers.storage_location.get_async_session') async def test_concurrent_no_lock_requests(self, mock_get_session, handler, mock_context, mock_storage_area, mock_storage_layers): """测试并发非锁定请求 - 可能有多个成功但不会冲突""" # 设置模拟数据库会话 mock_session = AsyncMock() mock_get_session.return_value.__aenter__.return_value = mock_session # 为非锁定场景,所有原子更新都应该成功 def mock_execute_update(*args, **kwargs): result = AsyncMock() result.rowcount = 1 # 所有更新都成功 return result # 输入参数(不锁定) input_params = { "groupName": '["test_area"]', "filled": "false", # 放货 "content": "test_goods", "lock": "false", # 不锁定 "retry": "false" } num_requests = 3 results = await self.simulate_concurrent_requests(handler, num_requests, input_params, mock_context) # 在不锁定的情况下,由于我们的原子更新机制,仍然应该保证并发安全 # 每个请求仍然需要通过原子检查占用状态 successful_results = [r for r in results if isinstance(r, dict) and r.get("success")] # 所有请求都应该能够成功(因为我们模拟了足够的空闲库位) assert len(successful_results) >= 1, "至少应该有一个请求成功" async def test_performance_benchmark(self, handler, mock_context): """性能基准测试 - 测试优化后的性能""" input_params = { "groupName": '["test_area"]', "filled": "false", "content": "test_goods", "lock": "true", "retry": "false" } # 模拟大量并发请求 num_requests = 50 start_time = asyncio.get_event_loop().time() with patch('services.execution.handlers.storage_location.get_async_session'): results = await self.simulate_concurrent_requests(handler, num_requests, input_params, mock_context) end_time = asyncio.get_event_loop().time() execution_time = end_time - start_time print(f"并发处理 {num_requests} 个请求耗时: {execution_time:.3f} 秒") print(f"平均每个请求耗时: {execution_time/num_requests*1000:.2f} 毫秒") # 性能断言(根据实际需求调整) assert execution_time < 10.0, f"处理 {num_requests} 个并发请求不应超过10秒,实际耗时 {execution_time:.3f} 秒" async def test_retry_mechanism_under_concurrency(self, handler, mock_context): """测试重试机制在并发环境下的表现""" input_params = { "groupName": '["test_area"]', "filled": "false", "content": "test_goods", "lock": "true", "retry": "true", "retryNum": "3", "retryPeriod": "100" # 100ms重试间隔 } # 模拟中等并发 num_requests = 10 with patch('services.execution.handlers.storage_location.get_async_session'): results = await self.simulate_concurrent_requests(handler, num_requests, input_params, mock_context) # 验证重试机制工作正常 successful_results = [r for r in results if isinstance(r, dict) and r.get("success")] # 在重试机制下,成功率应该有所提升 success_rate = len(successful_results) / num_requests print(f"重试机制下的成功率: {success_rate*100:.1f}%") # 断言至少有一定的成功率 assert success_rate >= 0.1, f"重试机制下成功率过低: {success_rate*100:.1f}%" # 运行测试的主函数 async def run_concurrent_tests(): """运行并发测试""" print("=== 开始并发获取密集库位处理器测试 ===") test_instance = TestConcurrentStorageLocation() # 创建必要的fixtures handler = test_instance.handler() mock_context = test_instance.mock_context() mock_storage_area = test_instance.mock_storage_area() mock_storage_layers = test_instance.mock_storage_layers() try: print("\n1. 测试并发锁定请求...") await test_instance.test_concurrent_lock_requests( None, handler, mock_context, mock_storage_area, mock_storage_layers ) print("✓ 并发锁定请求测试通过") print("\n2. 测试并发非锁定请求...") await test_instance.test_concurrent_no_lock_requests( None, handler, mock_context, mock_storage_area, mock_storage_layers ) print("✓ 并发非锁定请求测试通过") print("\n3. 性能基准测试...") await test_instance.test_performance_benchmark(handler, mock_context) print("✓ 性能基准测试通过") print("\n4. 重试机制并发测试...") await test_instance.test_retry_mechanism_under_concurrency(handler, mock_context) print("✓ 重试机制并发测试通过") print("\n=== 所有测试完成 ===") except Exception as e: print(f"❌ 测试失败: {str(e)}") raise if __name__ == "__main__": # 直接运行测试 asyncio.run(run_concurrent_tests())