免责声明:这是一种取巧(hacky)的做法,但它确实有效。
(如果一个办法看起来很蠢却有效,那它就不蠢。)
所有相关组件始终通过 ClientRegistrationRepository 来查找 ClientRegistration(以及其中配置的范围 scope)。
所以我的解决办法是给 InMemoryClientRegistrationRepository 包一层包装器(wrapper)。就我而言,我希望允许客户端请求任意额外的范围,因此把查询参数 scope 中给出的额外范围追加到已配置的范围上。
这是此解决方案的示例代码:
/**
 * Creates a {@link ClientRegistrationRepository} that wraps the configuration-based
 * {@link InMemoryClientRegistrationRepository} and, per lookup, merges any scopes named in
 * the current request's {@code scope} query parameter into the returned registration.
 *
 * <p>The lookup returns the configured registration unchanged when there is no current
 * servlet request, no query string, no {@code scope} parameter, or when every requested
 * scope is already configured. Otherwise it returns a copy of the registration whose scopes
 * are the union of the configured and the requested scopes.
 *
 * <p>NOTE(review): this lets any caller widen the scopes that are requested from the
 * authorization server. That is the stated intent of this workaround, but it relies on the
 * authorization server to reject scopes the client is not entitled to.
 *
 * @param properties the OAuth2 client properties the base registrations are built from
 * @return a repository that yields scope-augmented registrations, or {@code null} for an
 *         unknown registration id
 */
@Bean
public ClientRegistrationRepository clientRegistrationRepository(OAuth2ClientProperties properties) {
    final List<ClientRegistration> registrations = new ArrayList<>(
            OAuth2ClientPropertiesRegistrationAdapter.getClientRegistrations(properties).values());
    // This is the ClientRegistrationRepository that the default configuration would use.
    final ClientRegistrationRepository parent = new InMemoryClientRegistrationRepository(registrations);
    // The lambda is the wrapper around the configuration-based repository.
    return (registrationId) -> {
        final ClientRegistration clientRegistration = parent.findByRegistrationId(registrationId);
        if (clientRegistration == null) {
            return null;
        }
        // Resolve the servlet request bound to the current thread, if there is one.
        final HttpServletRequest request = Optional.ofNullable(RequestContextHolder.getRequestAttributes())
                .filter(ServletRequestAttributes.class::isInstance)
                .map(ServletRequestAttributes.class::cast)
                .map(ServletRequestAttributes::getRequest)
                .orElse(null);
        final String query = (request == null) ? null : request.getQueryString();
        if (query == null) {
            return clientRegistration;
        }
        final List<String> scopeQueryParam = parseQuery(query).get(OAuth2ParameterNames.SCOPE);
        if (scopeQueryParam == null) {
            return clientRegistration;
        }
        // Scope values are space-delimited. Consecutive spaces produce empty tokens, which
        // must be dropped so that an empty-string scope is never added to the registration.
        final Set<String> requestedScopes = scopeQueryParam.stream()
                .flatMap((v) -> Arrays.stream(v.split(" ")))
                .filter((scope) -> !scope.isEmpty())
                .collect(Collectors.toSet());
        // NOTE(review): guards against getScopes() returning null for a registration without
        // configured scopes; confirm against the Spring Security version in use.
        final Set<String> configuredScopes = clientRegistration.getScopes();
        if (requestedScopes.isEmpty()
                || (configuredScopes != null && configuredScopes.containsAll(requestedScopes))) {
            // Nothing new was requested, so the configured registration can be used as is.
            return clientRegistration;
        }
        final Set<String> resultingScopes = new HashSet<>(requestedScopes);
        if (configuredScopes != null) {
            resultingScopes.addAll(configuredScopes);
        }
        return ClientRegistration.withClientRegistration(clientRegistration)
                .scope(resultingScopes)
                .build();
    };
}
/**
 * Parses a raw, still URL-encoded query string into a multi-valued map of decoded
 * parameter names to decoded values.
 *
 * <p>Each {@code &}-separated segment is split at its FIRST {@code =} only, so values that
 * themselves contain {@code =} (for example Base64 padding) are kept intact. Empty segments
 * (as in {@code a=1&&b=2} or an empty query) are skipped. A parameter without a value
 * ({@code a} or {@code a=}) registers the name with no value added.
 *
 * @param query the raw query string, without the leading {@code ?}; must not be {@code null}
 * @return a map from decoded parameter name to the list of its decoded values
 * @throws IllegalArgumentException if a name or value contains a malformed percent-escape
 *         (propagated from {@link URLDecoder#decode(String, java.nio.charset.Charset)})
 */
private static MultiValueMap<String, String> parseQuery(String query) {
    final MultiValueMap<String, String> result = new LinkedMultiValueMap<>();
    for (String segment : query.split("&")) {
        // Split on the first '=' only; everything after it belongs to the value.
        final int separator = segment.indexOf('=');
        final String rawName = (separator < 0) ? segment : segment.substring(0, separator);
        final String rawValue = (separator < 0) ? "" : segment.substring(separator + 1);
        if (rawName.isEmpty() && rawValue.isEmpty()) {
            // Empty segment or a lone "=": nothing to record.
            continue;
        }
        final List<String> values = result.computeIfAbsent(
                URLDecoder.decode(rawName, StandardCharsets.UTF_8), (k) -> new ArrayList<>());
        if (!rawValue.isEmpty()) {
            values.add(URLDecoder.decode(rawValue, StandardCharsets.UTF_8));
        }
    }
    return result;
}