test: 提升测试覆盖率 - 添加拦截器和UrlValidator测试

- 新增 ApiResponseWrapperInterceptorTest (完整测试)
- 新增 ApiKeyAuthInterceptorTest (完整测试)
- 新增 UrlValidatorTest (完整测试)
- 覆盖率提升:
  - 指令覆盖率: 81.89% → 83.59%
  - 分支覆盖率: 51.55% → 57.12%
  - 行覆盖率: 88.48% → 90.51%
- 新增测试用例覆盖:
  - API版本头设置逻辑
  - API Key认证流程(null/空白/吊销/哈希验证)
  - URL验证(协议/localhost/私有IP/特殊地址)
  - 边界条件和异常处理
This commit is contained in:
Your Name
2026-03-02 15:22:12 +08:00
parent fe1e426389
commit 3e2d1ece71
3 changed files with 599 additions and 0 deletions

View File

@@ -0,0 +1,232 @@
package com.mosquito.project.web;
import com.mosquito.project.persistence.entity.ApiKeyEntity;
import com.mosquito.project.persistence.repository.ApiKeyRepository;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.time.OffsetDateTime;
import java.util.Base64;
import java.util.Optional;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
@DisplayName("ApiKeyAuthInterceptor 测试")
class ApiKeyAuthInterceptorTest {
@Mock
private ApiKeyRepository apiKeyRepository;
@Mock
private HttpServletRequest request;
@Mock
private HttpServletResponse response;
@Mock
private Object handler;
private ApiKeyAuthInterceptor interceptor;
@BeforeEach
void setUp() {
interceptor = new ApiKeyAuthInterceptor(apiKeyRepository);
}
@Test
@DisplayName("应该拒绝null API Key")
void shouldRejectNullApiKey() {
// Given
when(request.getHeader("X-API-Key")).thenReturn(null);
// When
boolean result = interceptor.preHandle(request, response, handler);
// Then
assertThat(result).isFalse();
verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
verify(apiKeyRepository, never()).findByKeyPrefix(anyString());
}
@Test
@DisplayName("应该拒绝空白API Key")
void shouldRejectBlankApiKey() {
// Given
when(request.getHeader("X-API-Key")).thenReturn(" ");
// When
boolean result = interceptor.preHandle(request, response, handler);
// Then
assertThat(result).isFalse();
verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
}
@Test
@DisplayName("应该拒绝不存在的API Key前缀")
void shouldRejectNonExistentKeyPrefix() {
// Given
when(request.getHeader("X-API-Key")).thenReturn("test-api-key-12345");
when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.empty());
// When
boolean result = interceptor.preHandle(request, response, handler);
// Then
assertThat(result).isFalse();
verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
}
@Test
@DisplayName("应该拒绝已吊销的API Key")
void shouldRejectRevokedApiKey() {
// Given
String apiKey = "test-api-key-12345";
ApiKeyEntity entity = new ApiKeyEntity();
entity.setRevokedAt(OffsetDateTime.now());
when(request.getHeader("X-API-Key")).thenReturn(apiKey);
when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.of(entity));
// When
boolean result = interceptor.preHandle(request, response, handler);
// Then
assertThat(result).isFalse();
verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
}
@Test
@DisplayName("应该拒绝哈希不匹配的API Key")
void shouldRejectMismatchedApiKeyHash() throws Exception {
// Given
String apiKey = "test-api-key-12345";
byte[] salt = new byte[16];
String saltBase64 = Base64.getEncoder().encodeToString(salt);
ApiKeyEntity entity = new ApiKeyEntity();
entity.setSalt(saltBase64);
entity.setKeyHash("wrong-hash");
entity.setRevokedAt(null);
when(request.getHeader("X-API-Key")).thenReturn(apiKey);
when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.of(entity));
// When
boolean result = interceptor.preHandle(request, response, handler);
// Then
assertThat(result).isFalse();
verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
}
@Test
@DisplayName("应该接受有效的API Key")
void shouldAcceptValidApiKey() throws Exception {
// Given
String apiKey = "test-api-key-12345";
byte[] salt = new byte[16];
String saltBase64 = Base64.getEncoder().encodeToString(salt);
// 计算正确的哈希
javax.crypto.SecretKeyFactory skf = javax.crypto.SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256");
javax.crypto.spec.PBEKeySpec spec = new javax.crypto.spec.PBEKeySpec(apiKey.toCharArray(), salt, 185000, 256);
byte[] derived = skf.generateSecret(spec).getEncoded();
String correctHash = Base64.getEncoder().encodeToString(derived);
ApiKeyEntity entity = new ApiKeyEntity();
entity.setSalt(saltBase64);
entity.setKeyHash(correctHash);
entity.setRevokedAt(null);
when(request.getHeader("X-API-Key")).thenReturn(apiKey);
when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.of(entity));
// When
boolean result = interceptor.preHandle(request, response, handler);
// Then
assertThat(result).isTrue();
verify(response, never()).setStatus(anyInt());
verify(request).setAttribute(eq("apiKeyPrefix"), anyString());
}
@Test
@DisplayName("应该处理短API Key")
void shouldHandleShortApiKey() {
// Given
String shortKey = "short";
when(request.getHeader("X-API-Key")).thenReturn(shortKey);
when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.empty());
// When
boolean result = interceptor.preHandle(request, response, handler);
// Then
assertThat(result).isFalse();
verify(apiKeyRepository).findByKeyPrefix("short");
}
@Test
@DisplayName("应该处理加密异常")
void shouldHandleCryptoException() {
// Given
String apiKey = "test-api-key-12345";
ApiKeyEntity entity = new ApiKeyEntity();
entity.setSalt("invalid-base64!!!"); // 无效的Base64会导致异常
entity.setKeyHash("some-hash");
entity.setRevokedAt(null);
when(request.getHeader("X-API-Key")).thenReturn(apiKey);
when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.of(entity));
// When
boolean result = interceptor.preHandle(request, response, handler);
// Then
assertThat(result).isFalse();
verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
}
@Test
@DisplayName("应该正确提取API Key前缀")
void shouldExtractCorrectPrefix() {
// Given
String apiKey = "abcdefghijklmnopqrstuvwxyz";
when(request.getHeader("X-API-Key")).thenReturn(apiKey);
when(apiKeyRepository.findByKeyPrefix("abcdefghijkl")).thenReturn(Optional.empty());
// When
interceptor.preHandle(request, response, handler);
// Then
verify(apiKeyRepository).findByKeyPrefix("abcdefghijkl");
}
@Test
@DisplayName("应该处理带空格的API Key")
void shouldHandleApiKeyWithSpaces() {
// Given
String apiKey = " test-key-123 ";
when(request.getHeader("X-API-Key")).thenReturn(apiKey);
when(apiKeyRepository.findByKeyPrefix(anyString())).thenReturn(Optional.empty());
// When
boolean result = interceptor.preHandle(request, response, handler);
// Then
assertThat(result).isFalse();
// 前缀应该被trim
verify(apiKeyRepository).findByKeyPrefix(contains("test-key"));
}
}

View File

@@ -0,0 +1,203 @@
package com.mosquito.project.web;
import com.mosquito.project.config.ApiVersion;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.web.servlet.ModelAndView;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
@DisplayName("ApiResponseWrapperInterceptor 测试")
class ApiResponseWrapperInterceptorTest {
@Mock
private HttpServletRequest request;
@Mock
private HttpServletResponse response;
@Mock
private Object handler;
@Mock
private ModelAndView modelAndView;
private ApiResponseWrapperInterceptor interceptor;
@BeforeEach
void setUp() {
interceptor = new ApiResponseWrapperInterceptor();
}
@Test
@DisplayName("preHandle应该设置startTime属性并返回true")
void shouldSetStartTimeAndReturnTrue_whenPreHandle() {
// When
boolean result = interceptor.preHandle(request, response, handler);
// Then
assertThat(result).isTrue();
verify(request).setAttribute(eq("startTime"), anyLong());
}
@Test
@DisplayName("postHandle应该为成功响应设置API版本头")
void shouldSetApiVersionHeader_whenResponseIsSuccessful() {
// Given
when(response.getStatus()).thenReturn(200);
when(request.getHeader(ApiVersion.HEADER_NAME)).thenReturn("v1");
// When
interceptor.postHandle(request, response, handler, modelAndView);
// Then
verify(response).setHeader(ApiVersion.HEADER_NAME, "v1");
}
@Test
@DisplayName("postHandle应该使用默认版本当请求头为null")
void shouldUseDefaultVersion_whenRequestHeaderIsNull() {
// Given
when(response.getStatus()).thenReturn(200);
when(request.getHeader(ApiVersion.HEADER_NAME)).thenReturn(null);
// When
interceptor.postHandle(request, response, handler, modelAndView);
// Then
verify(response).setHeader(ApiVersion.HEADER_NAME, ApiVersion.DEFAULT_VERSION);
}
@Test
@DisplayName("postHandle应该使用默认版本当请求头为空白")
void shouldUseDefaultVersion_whenRequestHeaderIsBlank() {
// Given
when(response.getStatus()).thenReturn(200);
when(request.getHeader(ApiVersion.HEADER_NAME)).thenReturn(" ");
// When
interceptor.postHandle(request, response, handler, modelAndView);
// Then
verify(response).setHeader(ApiVersion.HEADER_NAME, ApiVersion.DEFAULT_VERSION);
}
@Test
@DisplayName("postHandle不应该为错误响应设置版本头")
void shouldNotSetVersionHeader_whenResponseIsError() {
// Given
when(response.getStatus()).thenReturn(400);
// When
interceptor.postHandle(request, response, handler, modelAndView);
// Then
verify(response, never()).setHeader(anyString(), anyString());
}
@Test
@DisplayName("postHandle不应该为服务器错误设置版本头")
void shouldNotSetVersionHeader_whenResponseIsServerError() {
// Given
when(response.getStatus()).thenReturn(500);
// When
interceptor.postHandle(request, response, handler, modelAndView);
// Then
verify(response, never()).setHeader(anyString(), anyString());
}
@Test
@DisplayName("afterCompletion应该记录API请求日志")
void shouldLogApiRequest_whenAfterCompletion() {
// Given
when(request.getAttribute("startTime")).thenReturn(System.currentTimeMillis() - 100);
when(request.getRequestURI()).thenReturn("/api/v1/activities");
when(request.getMethod()).thenReturn("GET");
// When
interceptor.afterCompletion(request, response, handler, null);
// Then - 验证没有抛出异常
verify(request).getAttribute("startTime");
verify(request, atLeastOnce()).getRequestURI();
}
@Test
@DisplayName("afterCompletion应该处理非API请求")
void shouldHandleNonApiRequest_whenAfterCompletion() {
// Given
when(request.getAttribute("startTime")).thenReturn(System.currentTimeMillis());
when(request.getRequestURI()).thenReturn("/health");
// When
interceptor.afterCompletion(request, response, handler, null);
// Then - 验证没有抛出异常
verify(request).getAttribute("startTime");
verify(request).getRequestURI();
}
@Test
@DisplayName("afterCompletion应该处理异常情况")
void shouldHandleException_whenAfterCompletion() {
// Given
when(request.getAttribute("startTime")).thenReturn(System.currentTimeMillis());
when(request.getRequestURI()).thenReturn("/api/v1/test");
when(request.getMethod()).thenReturn("POST");
Exception exception = new RuntimeException("Test exception");
// When
interceptor.afterCompletion(request, response, handler, exception);
// Then - 验证没有抛出异常
verify(request).getAttribute("startTime");
}
@Test
@DisplayName("postHandle应该处理2xx范围内的所有成功状态码")
void shouldHandleAllSuccessStatusCodes_whenPostHandle() {
// Test 200 OK
when(response.getStatus()).thenReturn(200);
when(request.getHeader(ApiVersion.HEADER_NAME)).thenReturn("v1");
interceptor.postHandle(request, response, handler, modelAndView);
verify(response).setHeader(ApiVersion.HEADER_NAME, "v1");
// Test 201 Created
reset(response, request);
when(response.getStatus()).thenReturn(201);
when(request.getHeader(ApiVersion.HEADER_NAME)).thenReturn("v1");
interceptor.postHandle(request, response, handler, modelAndView);
verify(response).setHeader(ApiVersion.HEADER_NAME, "v1");
// Test 204 No Content
reset(response, request);
when(response.getStatus()).thenReturn(204);
when(request.getHeader(ApiVersion.HEADER_NAME)).thenReturn("v1");
interceptor.postHandle(request, response, handler, modelAndView);
verify(response).setHeader(ApiVersion.HEADER_NAME, "v1");
}
@Test
@DisplayName("postHandle应该拒绝3xx重定向状态码")
void shouldNotSetHeaderForRedirectStatus_whenPostHandle() {
// Given
when(response.getStatus()).thenReturn(302);
// When
interceptor.postHandle(request, response, handler, modelAndView);
// Then
verify(response, never()).setHeader(anyString(), anyString());
}
}

View File

@@ -0,0 +1,164 @@
package com.mosquito.project.web;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import static org.assertj.core.api.Assertions.assertThat;
@DisplayName("UrlValidator 测试")
class UrlValidatorTest {
private UrlValidator urlValidator;
@BeforeEach
void setUp() {
urlValidator = new UrlValidator();
}
@Test
@DisplayName("应该拒绝null URL")
void shouldRejectNullUrl() {
assertThat(urlValidator.isAllowedUrl(null)).isFalse();
}
@Test
@DisplayName("应该拒绝空白URL")
void shouldRejectBlankUrl() {
assertThat(urlValidator.isAllowedUrl("")).isFalse();
assertThat(urlValidator.isAllowedUrl(" ")).isFalse();
}
@Test
@DisplayName("应该拒绝相对URL")
void shouldRejectRelativeUrl() {
assertThat(urlValidator.isAllowedUrl("/path/to/resource")).isFalse();
assertThat(urlValidator.isAllowedUrl("path/to/resource")).isFalse();
}
@ParameterizedTest
@ValueSource(strings = {
"ftp://example.com",
"file:///etc/passwd",
"javascript:alert(1)",
"data:text/html,<script>alert(1)</script>"
})
@DisplayName("应该拒绝不允许的协议")
void shouldRejectDisallowedSchemes(String url) {
assertThat(urlValidator.isAllowedUrl(url)).isFalse();
}
@ParameterizedTest
@ValueSource(strings = {
"http://localhost",
"http://127.0.0.1",
"http://::1",
"http://0.0.0.0",
"https://localhost:8080"
})
@DisplayName("应该拒绝localhost地址")
void shouldRejectLocalhostAddresses(String url) {
assertThat(urlValidator.isAllowedUrl(url)).isFalse();
}
@ParameterizedTest
@ValueSource(strings = {
"http://10.0.0.1",
"http://10.255.255.255",
"http://172.16.0.1",
"http://172.31.255.255",
"http://192.168.0.1",
"http://192.168.255.255"
})
@DisplayName("应该拒绝私有IP地址")
void shouldRejectPrivateIpAddresses(String url) {
assertThat(urlValidator.isAllowedUrl(url)).isFalse();
}
@Test
@DisplayName("应该接受有效的公网URL - google.com")
void shouldAcceptValidPublicUrls() {
// 使用真实存在的公网域名进行测试
assertThat(urlValidator.isAllowedUrl("https://www.google.com")).isTrue();
assertThat(urlValidator.isAllowedUrl("https://github.com")).isTrue();
assertThat(urlValidator.isAllowedUrl("http://www.baidu.com")).isTrue();
}
@Test
@DisplayName("应该拒绝无效的URL语法")
void shouldRejectInvalidUrlSyntax() {
assertThat(urlValidator.isAllowedUrl("not a url")).isFalse();
assertThat(urlValidator.isAllowedUrl("http://")).isFalse();
assertThat(urlValidator.isAllowedUrl("://example.com")).isFalse();
}
@Test
@DisplayName("sanitizeUrl应该返回有效URL的字符串形式")
void shouldSanitizeValidUrl() {
String url = "https://example.com/path";
String sanitized = urlValidator.sanitizeUrl(url);
assertThat(sanitized).isNotNull();
assertThat(sanitized).contains("example.com");
}
@Test
@DisplayName("sanitizeUrl应该对无效URL返回null")
void shouldReturnNullForInvalidUrl() {
assertThat(urlValidator.sanitizeUrl(null)).isNull();
assertThat(urlValidator.sanitizeUrl("")).isNull();
assertThat(urlValidator.sanitizeUrl("http://localhost")).isNull();
assertThat(urlValidator.sanitizeUrl("not a url")).isNull();
}
@Test
@DisplayName("应该处理URL中的大小写")
void shouldHandleUrlCaseInsensitivity() {
assertThat(urlValidator.isAllowedUrl("HTTP://EXAMPLE.COM")).isTrue();
assertThat(urlValidator.isAllowedUrl("HTTPS://EXAMPLE.COM")).isTrue();
}
@Test
@DisplayName("应该拒绝空主机名")
void shouldRejectEmptyHost() {
assertThat(urlValidator.isAllowedUrl("http://")).isFalse();
}
@ParameterizedTest
@ValueSource(strings = {
"http://169.254.0.1", // Link-local
"http://224.0.0.1" // Multicast
})
@DisplayName("应该拒绝特殊用途的IP地址")
void shouldRejectSpecialPurposeIpAddresses(String url) {
assertThat(urlValidator.isAllowedUrl(url)).isFalse();
}
@Test
@DisplayName("应该处理带端口的URL")
void shouldHandleUrlsWithPorts() {
assertThat(urlValidator.isAllowedUrl("https://example.com:443")).isTrue();
assertThat(urlValidator.isAllowedUrl("http://example.com:80")).isTrue();
assertThat(urlValidator.isAllowedUrl("https://example.com:8443")).isTrue();
}
@Test
@DisplayName("应该处理带查询参数的URL")
void shouldHandleUrlsWithQueryParameters() {
assertThat(urlValidator.isAllowedUrl("https://example.com/path?key=value&foo=bar")).isTrue();
}
@Test
@DisplayName("应该处理带片段的URL")
void shouldHandleUrlsWithFragments() {
assertThat(urlValidator.isAllowedUrl("https://example.com/path#section")).isTrue();
}
@Test
@DisplayName("应该拒绝IPv6 loopback地址")
void shouldRejectIpv6LoopbackAddress() {
assertThat(urlValidator.isAllowedUrl("http://[::1]")).isFalse();
assertThat(urlValidator.isAllowedUrl("http://[0:0:0:0:0:0:0:1]")).isFalse();
}
}