diff --git a/build.gradle b/build.gradle index bd35db72ae828e09e62ca2dafc628e73ff38564f..907eca77cf4e2814215ee64a47598f487479786e 100644 --- a/build.gradle +++ b/build.gradle @@ -30,7 +30,14 @@ dependencies { implementation 'org.springframework.boot:spring-boot-starter-log4j2:3.0.4' implementation 'org.springframework.boot:spring-boot-starter-actuator' implementation 'org.springdoc:springdoc-openapi-starter-webmvc-ui:2.0.4' + implementation 'jakarta.validation:jakarta.validation-api:3.0.2' // common + implementation 'com.google.guava:guava:31.1-jre' // common + implementation 'org.apache.commons:commons-text:1.10.0' // common + implementation 'com.auth0:java-jwt:4.3.0' // common + implementation 'org.aspectj:aspectjweaver:1.9.19' // common + implementation 'org.aspectj:aspectjrt:1.9.19' // common implementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.15.0-rc1' + testImplementation 'org.springframework.boot:spring-boot-starter-test' } diff --git a/src/main/java/edu/umd/dawn/common/annotations/AspectBase.java b/src/main/java/edu/umd/dawn/common/annotations/AspectBase.java new file mode 100644 index 0000000000000000000000000000000000000000..f6099da05742533dcbfa4e3439b996ddd068808e --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/annotations/AspectBase.java @@ -0,0 +1,40 @@ +package edu.umd.dawn.common.annotations; + +import java.lang.annotation.Annotation; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.annotation.Pointcut; +import org.aspectj.lang.reflect.MethodSignature; +import org.springframework.stereotype.Component; + +/** + * base class for custom aspects for annotations. includes some helper functions + */ +@Aspect +@Component +public class AspectBase<T extends Annotation> { + + @Pointcut(value = "execution(* *.*(..))") + protected void allMethods() {} + + protected String getFullDescriptor(ProceedingJoinPoint jointPoint) { + String className = ((MethodSignature) jointPoint.getSignature()).getDeclaringTypeName(); + String methodName = + ((MethodSignature) jointPoint.getSignature()).getMethod().getName(); + return className + "." + methodName; + } + + protected T grabAnnotation(ProceedingJoinPoint jointPoint, Class<T> clazz) { + return ((MethodSignature) jointPoint.getSignature()).getMethod().getAnnotation(clazz); + } + + protected T grabAnnotationFromClass(ProceedingJoinPoint jointPoint, Class<T> clazz) { + return (T) + ((MethodSignature) jointPoint.getSignature()).getDeclaringType().getAnnotation(clazz); + } + + // protected T grabMethodArgument(ProceedingJoinPoint jointPoint, Class<T> clazz) { + // return (T) + // ((MethodSignature) jointPoint.getSignature()).getDeclaringType().getAnnotation(clazz); + // } +} diff --git a/src/main/java/edu/umd/dawn/common/annotations/Deprecated.java b/src/main/java/edu/umd/dawn/common/annotations/Deprecated.java new file mode 100644 index 0000000000000000000000000000000000000000..ca1cfdcc83ea931a87364abcffd3ec024ed0b1b3 --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/annotations/Deprecated.java @@ -0,0 +1,15 @@ +package edu.umd.dawn.common.annotations; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * This serves as a stronger deprecation notice - will report warnings + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface Deprecated { + public String value() default ""; +} diff --git a/src/main/java/edu/umd/dawn/common/annotations/DeprecatedAspect.java b/src/main/java/edu/umd/dawn/common/annotations/DeprecatedAspect.java new file mode 100644 index 0000000000000000000000000000000000000000..28acc599009a337c3b78e745a4be3d67d6de1199 --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/annotations/DeprecatedAspect.java @@ -0,0 +1,27 @@ +package edu.umd.dawn.common.annotations; + +import lombok.extern.log4j.Log4j2; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.springframework.stereotype.Component; + +@Aspect +@Component +@Log4j2 +public class DeprecatedAspect extends AspectBase<Deprecated> { + + @Around("@annotation(Deprecated)") + public Object run(ProceedingJoinPoint jointPoint) throws Exception, Throwable { + Deprecated annotation = grabAnnotation(jointPoint, Deprecated.class); + String descriptor = getFullDescriptor(jointPoint); + + if (!annotation.value().equals("")) { + log.warn(String.format("method %s is deprecated - reason: %s", descriptor, annotation.value())); + } else { + log.warn(String.format("method %s is deprecated", descriptor)); + } + Object proceed = jointPoint.proceed(); + return proceed; + } +} diff --git a/src/main/java/edu/umd/dawn/common/annotations/RoleRestriction.java b/src/main/java/edu/umd/dawn/common/annotations/RoleRestriction.java new file mode 100644 index 0000000000000000000000000000000000000000..1c0b69d1b2144219024c66755fa8aaf220f1d84e --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/annotations/RoleRestriction.java @@ -0,0 +1,16 @@ +package edu.umd.dawn.common.annotations; + +import edu.umd.dawn.common.enums.Role; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface RoleRestriction { + + public Role value() default Role.USER; + + boolean disableWarning() default false; +} diff --git a/src/main/java/edu/umd/dawn/common/annotations/RoleRestrictionAspect.java b/src/main/java/edu/umd/dawn/common/annotations/RoleRestrictionAspect.java new file mode 100644 index 0000000000000000000000000000000000000000..b57c860a9b262c6dae372739056283c09a0099a1 --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/annotations/RoleRestrictionAspect.java @@ -0,0 +1,45 @@ +package edu.umd.dawn.common.annotations; + +import edu.umd.dawn.common.exceptions.BaseExceptions; +import edu.umd.dawn.common.exceptions.DawnException; +import edu.umd.dawn.common.jwt.Claims; +import jakarta.servlet.http.HttpServletRequest; +import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +@Aspect +@Component +@Log4j2 +@RequiredArgsConstructor(onConstructor = @__(@Autowired)) // Autowired annotated lombok generated constructor +public class RoleRestrictionAspect extends AspectBase<RoleRestriction> { + + private final HttpServletRequest request; + + @Value("${config.local}") + private boolean local = false; + + @Around("@annotation(RoleRestriction)") + public Object run(ProceedingJoinPoint jointPoint) throws Exception, Throwable { + RoleRestriction annotation = grabAnnotation(jointPoint, RoleRestriction.class); + if (!local) { + Claims claims = (Claims) request.getAttribute("claims"); + if (claims == null) { + throw new DawnException(BaseExceptions.FORBIDDEN); + } + // call db to check + } else if (!annotation.disableWarning()) { + log.warn("RoleRestriction annotation has been disabled - if this is a production environment, consider this" + + " a critical security error. Request with unknown role accessed endpoint with restriction " + + annotation.value().toString() + "."); + } + + Object proceed = jointPoint.proceed(); + return proceed; + } +} diff --git a/src/main/java/edu/umd/dawn/common/annotations/Traceable.java b/src/main/java/edu/umd/dawn/common/annotations/Traceable.java new file mode 100644 index 0000000000000000000000000000000000000000..df9bacb1c0e78d20dbdeddba8e2be2dec4eb3ebd --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/annotations/Traceable.java @@ -0,0 +1,16 @@ +package edu.umd.dawn.common.annotations; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import org.apache.logging.log4j.spi.StandardLevel; + +/** + * Traces entry and exit from functions at specified log level + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface Traceable { + StandardLevel level() default StandardLevel.TRACE; +} diff --git a/src/main/java/edu/umd/dawn/common/annotations/TraceableAspect.java b/src/main/java/edu/umd/dawn/common/annotations/TraceableAspect.java new file mode 100644 index 0000000000000000000000000000000000000000..f2afac9b15ceb8f397152c81c19690964d555c61 --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/annotations/TraceableAspect.java @@ -0,0 +1,31 @@ +package edu.umd.dawn.common.annotations; + +import lombok.extern.log4j.Log4j2; +import org.apache.logging.log4j.Level; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.springframework.stereotype.Component; + +/** + * Traces entry and exit from functions at specified log level + */ +@Log4j2 +@Aspect +@Component +public class TraceableAspect extends AspectBase<Traceable> { + + @Around("@annotation(Traceable)") + public Object run(ProceedingJoinPoint jointPoint) throws Exception, Throwable { + Traceable traceable = grabAnnotation(jointPoint, Traceable.class); + + String descriptor = getFullDescriptor(jointPoint); + + log.log( + Level.forName(traceable.level().name(), traceable.level().intLevel()), + String.format("entering %s", descriptor)); + Object proceed = jointPoint.proceed(); + log.info(String.format("exiting %s", descriptor)); + return proceed; + } +} diff --git a/src/main/java/edu/umd/dawn/common/annotations/Unfinished.java b/src/main/java/edu/umd/dawn/common/annotations/Unfinished.java new file mode 100644 index 0000000000000000000000000000000000000000..12d2e07d822249a36613a845fdd190d1e0b5029f --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/annotations/Unfinished.java @@ -0,0 +1,15 @@ +package edu.umd.dawn.common.annotations; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * warns that a method is unfinished + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD}) +public @interface Unfinished { + String value() default ""; +} diff --git a/src/main/java/edu/umd/dawn/common/annotations/UnfinishedAspect.java b/src/main/java/edu/umd/dawn/common/annotations/UnfinishedAspect.java new file mode 100644 index 0000000000000000000000000000000000000000..781dbf523c40adb2bf1576274323d59b80f9afa8 --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/annotations/UnfinishedAspect.java @@ -0,0 +1,30 @@ +package edu.umd.dawn.common.annotations; + +import lombok.extern.log4j.Log4j2; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.springframework.stereotype.Component; + +/** + * warns that a method is unfinished + */ +@Aspect +@Component +@Log4j2 +public class UnfinishedAspect extends AspectBase<Unfinished> { + + @Around("@annotation(Unfinished)") + public Object run(ProceedingJoinPoint jointPoint) throws Exception, Throwable { + Unfinished unfinished = grabAnnotation(jointPoint, Unfinished.class); + + String descriptor = getFullDescriptor(jointPoint); + if (!unfinished.value().equals("")) { + log.warn(String.format("method %s is unfinished - reason: %s", descriptor, unfinished.value())); + } else { + log.warn(String.format("method %s is unfinished", descriptor)); + } + Object proceed = jointPoint.proceed(); + return proceed; + } +} diff --git a/src/main/java/edu/umd/dawn/common/annotations/Utils.java b/src/main/java/edu/umd/dawn/common/annotations/Utils.java new file mode 100644 index 0000000000000000000000000000000000000000..107fb736e065a271e96007956d80b93182773ce1 --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/annotations/Utils.java @@ -0,0 +1,19 @@ +package edu.umd.dawn.common.annotations; + +import java.lang.annotation.Annotation; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.reflect.MethodSignature; + +public class Utils<T> { + + public static String getFullDescriptor(ProceedingJoinPoint jointPoint) { + String className = ((MethodSignature) jointPoint.getSignature()).getDeclaringTypeName(); + String methodName = + ((MethodSignature) jointPoint.getSignature()).getMethod().getName(); + return className + "." + methodName; + } + + public static <T extends Annotation> T grabAnnotation(ProceedingJoinPoint jointPoint, Class<T> clazz) { + return ((MethodSignature) jointPoint.getSignature()).getMethod().getAnnotation(clazz); + } +} diff --git a/src/main/java/edu/umd/dawn/common/converter/CaseInsensitiveEnumConverter.java b/src/main/java/edu/umd/dawn/common/converter/CaseInsensitiveEnumConverter.java new file mode 100644 index 0000000000000000000000000000000000000000..b82f3e81594a3b44b35f4d57f15d25d90bb1f595 --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/converter/CaseInsensitiveEnumConverter.java @@ -0,0 +1,16 @@ +package edu.umd.dawn.common.converter; + +import org.springframework.core.convert.converter.Converter; + +public class CaseInsensitiveEnumConverter<T extends Enum<T>> implements Converter<String, T> { + private Class<T> enumClass; + + public CaseInsensitiveEnumConverter(Class<T> enumClass) { + this.enumClass = enumClass; + } + + @Override + public T convert(String from) { + return T.valueOf(enumClass, from.toUpperCase()); + } +} diff --git a/src/main/java/edu/umd/dawn/common/enums/Role.java b/src/main/java/edu/umd/dawn/common/enums/Role.java new file mode 100644 index 0000000000000000000000000000000000000000..ebbc12970c5fbe02ae8b9dafdb6deb23f8f3801b --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/enums/Role.java @@ -0,0 +1,22 @@ +package edu.umd.dawn.common.enums; + +public enum Role { + USER(0), + REPORTER(1), + ADMIN(2), + SUPER(3); + + private int level; + + private Role(int level) { + this.level = level; + } + + public static Role fromInt(int level) { + return Role.values()[level]; + } + + public boolean isAccessAllowed(Role minimum) { + return minimum.level <= level; + } +} diff --git a/src/main/java/edu/umd/dawn/common/exceptions/BaseExceptions.java b/src/main/java/edu/umd/dawn/common/exceptions/BaseExceptions.java index 40e3e9493f4d93f4587892b059e3e91e4ed7c97a..c4f524853b13b416a7848522d934b11f4a52dfad 100644 --- a/src/main/java/edu/umd/dawn/common/exceptions/BaseExceptions.java +++ b/src/main/java/edu/umd/dawn/common/exceptions/BaseExceptions.java @@ -9,4 +9,18 @@ public class BaseExceptions { new DawnExceptionParameters(500, "INTERNAL_SERVER_ERROR", "Internal server error", "out of bounds"); public static final DawnExceptionParameters UNHANDLED_INTERNAL_SERVER_ERROR = new DawnExceptionParameters(500, "INTERNAL_SERVER_ERROR", "Internal server error", "unhandled error"); + public static final DawnExceptionParameters FORBIDDEN = + new DawnExceptionParameters(403, "FORBIDDEN", "user cannot access requested resource", ""); + public static final DawnExceptionParameters INVALID_JWT = new DawnExceptionParameters( + 401, "UNAUTHORIZED", "Invalid or empty JWT provided", "JWT provided is invalid"); + public static final DawnExceptionParameters NOT_FOUND = + new DawnExceptionParameters(404, "NOT_FOUND", "Resource not found", ""); + + public static DawnExceptionParameters INVALID_LIMIT(int limit) { + return new DawnExceptionParameters(400, "BAD_REQUEST", String.format("limit of %d is invalid", limit), ""); + } + + public static DawnExceptionParameters INVALID_OFFSET(int offset) { + return new DawnExceptionParameters(400, "BAD_REQUEST", String.format("offset of %d is invalid", offset), ""); + } } diff --git a/src/main/java/edu/umd/dawn/common/exceptions/CustomExceptionHandler.java b/src/main/java/edu/umd/dawn/common/exceptions/CustomExceptionHandler.java index f4188d60b53a51abf82895db17ddea8f773a9622..c9c06698e44c45f4ef3adfbe5072a27116f27e1d 100644 --- a/src/main/java/edu/umd/dawn/common/exceptions/CustomExceptionHandler.java +++ b/src/main/java/edu/umd/dawn/common/exceptions/CustomExceptionHandler.java @@ -5,7 +5,6 @@ import java.io.StringWriter; import lombok.extern.slf4j.Slf4j; import org.slf4j.MDC; import org.springframework.beans.factory.annotation.Value; -import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Component; import org.springframework.validation.BindException; @@ -55,9 +54,11 @@ public class CustomExceptionHandler { @ExceptionHandler(BindException.class) protected ResponseEntity<Object> handleBindException(BindException ex, WebRequest request) { - FieldError err = ex.getFieldError(); - DawnException wrapped = new DawnException(BaseExceptions.BAD_REQUEST, ex, "Value " + err.getRejectedValue() + " is invalid for field " + err.getField()); + DawnException wrapped = new DawnException( + BaseExceptions.BAD_REQUEST, + ex, + "Value " + err.getRejectedValue() + " is invalid for field " + err.getField()); return returnDawnException(wrapped); } diff --git a/src/main/java/edu/umd/dawn/common/filters/CaseInsensitiveFilter.java b/src/main/java/edu/umd/dawn/common/filters/CaseInsensitiveFilter.java new file mode 100644 index 0000000000000000000000000000000000000000..5abc7f51334ddfb1edd7761dc703321c9c82f5fb --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/filters/CaseInsensitiveFilter.java @@ -0,0 +1,77 @@ +package edu.umd.dawn.common.filters; + +import com.google.common.base.CaseFormat; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Enumeration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import lombok.extern.log4j.Log4j2; +import org.springframework.stereotype.Component; +import org.springframework.web.filter.OncePerRequestFilter; + +/** + * This filter will convert any underscore case variable to camelcase. + * It is intended to support backwards compatibility with golang requests which had variables written + * like user_id instead of userId. + */ +@Log4j2 +@Component +public class CaseInsensitiveFilter extends OncePerRequestFilter { + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + final Map<String, String[]> formattedParams = new ConcurrentHashMap<>(); + + List<String> invalidParameterCase = new ArrayList<>(); + + for (String param : request.getParameterMap().keySet()) { + String formattedParam = param; + if (param.contains("_")) { + invalidParameterCase.add(param); + formattedParam = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, param); + } + formattedParams.put(formattedParam, request.getParameterValues(param)); + } + + filterChain.doFilter( + new HttpServletRequestWrapper(request) { + @Override + public String getParameter(String name) { + return formattedParams.containsKey(name) + ? formattedParams.get(name)[0] + : null; + } + + @Override + public Enumeration<String> getParameterNames() { + return Collections.enumeration(formattedParams.keySet()); + } + + @Override + public String[] getParameterValues(String name) { + return formattedParams.get(name); + } + + @Override + public Map<String, String[]> getParameterMap() { + return formattedParams; + } + }, + response); + + if (invalidParameterCase.size() > 0) { + log.warn(String.format( + "client is still using deprecated naming scheme for parameter(s) %s", + invalidParameterCase.toString())); + } + } +} diff --git a/src/main/java/edu/umd/dawn/common/interceptor/RequestInterceptor.java b/src/main/java/edu/umd/dawn/common/interceptor/RequestInterceptor.java index 3040313b8e24bce7708b61ff044705290c70fe77..b0e9013a5530468352cb9a653b85a2b0ad9780ae 100644 --- a/src/main/java/edu/umd/dawn/common/interceptor/RequestInterceptor.java +++ b/src/main/java/edu/umd/dawn/common/interceptor/RequestInterceptor.java @@ -68,7 +68,7 @@ public class RequestInterceptor implements HandlerInterceptor { Object p = request.getAttribute(HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE); if (p != null) { - Map<String, String> pathParams = (Map<String, String>)p; + Map<String, String> pathParams = (Map<String, String>) p; pathParams.forEach((k, v) -> parameters.merge(k, v, String::concat)); } @@ -83,8 +83,7 @@ public class RequestInterceptor implements HandlerInterceptor { long executeTime = endTime - startTime; String fullPath = request.getRequestURI(); - String servletPattern = (String) request.getAttribute( - HandlerMapping.BEST_MATCHING_PATTERN_ATTRIBUTE); + String servletPattern = (String) request.getAttribute(HandlerMapping.BEST_MATCHING_PATTERN_ATTRIBUTE); String servletPath = request.getServletPath(); String path = fullPath.replace(servletPath, servletPattern); diff --git a/src/main/java/edu/umd/dawn/common/jwt/Claims.java b/src/main/java/edu/umd/dawn/common/jwt/Claims.java new file mode 100644 index 0000000000000000000000000000000000000000..16b4006ac49a859f290e619d4291759848e1f05f --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/jwt/Claims.java @@ -0,0 +1,20 @@ +package edu.umd.dawn.common.jwt; + +import com.auth0.jwt.interfaces.DecodedJWT; +import lombok.AllArgsConstructor; +import lombok.Getter; + +@Getter +@AllArgsConstructor +public class Claims { + + private String userId; + + public static Claims build(JWTUtil parser) { + return new Claims(parser.getJwt().getClaim("id").asString()); + } + + public static Claims build(DecodedJWT jwt) { + return new Claims(jwt.getClaim("id").asString()); + } +} diff --git a/src/main/java/edu/umd/dawn/common/jwt/JWTUtil.java b/src/main/java/edu/umd/dawn/common/jwt/JWTUtil.java new file mode 100644 index 0000000000000000000000000000000000000000..36cfaea4396102b35b95377c03057163d1d76df9 --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/jwt/JWTUtil.java @@ -0,0 +1,49 @@ +package edu.umd.dawn.common.jwt; + +import com.auth0.jwt.JWT; +import com.auth0.jwt.algorithms.Algorithm; +import com.auth0.jwt.interfaces.DecodedJWT; +import edu.umd.dawn.common.exceptions.BaseExceptions; +import edu.umd.dawn.common.exceptions.DawnException; +import lombok.Getter; + +@Getter +public class JWTUtil { + + private Algorithm algorithm; + private String accessSecret; + private String jwtString; + private DecodedJWT jwt; + + public JWTUtil(String accessSecret, String jwt) { + this.accessSecret = accessSecret; + this.jwtString = jwt; + initAlgorithm(); + decode(); + } + + private void initAlgorithm() { + this.algorithm = Algorithm.HMAC256(accessSecret); + } + + private void decode() { + try { + this.jwt = JWT.require(algorithm) + .acceptLeeway(1) // 1 sec for nbf and iat + .acceptExpiresAt(5) // 5 secs for exp + .withClaimPresence("id") // need id in the claim + .build() + .verify(jwtString); + } catch (Exception e) { + throw new DawnException(BaseExceptions.INVALID_JWT, e); + } + } + + public Claims getClaims() { + return Claims.build(this); + } + + public static JWTUtil parse(String accessSecret, String jwt) { + return new JWTUtil(accessSecret, jwt); + } +} diff --git a/src/main/java/edu/umd/dawn/common/services/UserAuthService.java b/src/main/java/edu/umd/dawn/common/services/UserAuthService.java new file mode 100644 index 0000000000000000000000000000000000000000..e7e7242befa848a00e2ba3b411aa11cac97a8264 --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/services/UserAuthService.java @@ -0,0 +1,52 @@ +package edu.umd.dawn.common.services; + +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoDatabase; +import edu.umd.dawn.common.enums.Role; +import edu.umd.dawn.common.exceptions.BaseExceptions; +import edu.umd.dawn.common.exceptions.DawnException; +import java.util.Optional; +import org.bson.Document; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +@Component +public class UserAuthService { + + private MongoClient mongoClient; + private MongoDatabase mongoDatabase; + + @Value("${common.annotation.mongodb.uri}") + private String uri; + + @Value("${common.annotation.mongodb.database}") + private String db; + + protected MongoDatabase getDatabase() { + if (mongoClient == null) { + mongoClient = MongoClients.create(uri); + } + if (mongoDatabase == null) { + mongoDatabase = mongoClient.getDatabase(db); + } + return mongoDatabase; + } + + public Role getUserRole(String userId) { + MongoDatabase database = getDatabase(); + MongoCollection<Document> collection = database.getCollection("users"); + + Document query = new Document(); + query.put("_id", userId); + + // should prob throw an auth type error + Document doc = Optional.of(collection.find(query).first()) + .orElseThrow(() -> new DawnException(BaseExceptions.NOT_FOUND)); + + int role = (int) doc.get("role"); + + return Role.fromInt(role); + } +} diff --git a/src/main/java/edu/umd/dawn/common/utils/OffsetLimitPageRequest.java b/src/main/java/edu/umd/dawn/common/utils/OffsetLimitPageRequest.java new file mode 100644 index 0000000000000000000000000000000000000000..525459f0059f11cc21b7caa8b36bb9bf598f3bb7 --- /dev/null +++ b/src/main/java/edu/umd/dawn/common/utils/OffsetLimitPageRequest.java @@ -0,0 +1,136 @@ +package edu.umd.dawn.common.utils; + +import edu.umd.dawn.common.exceptions.BaseExceptions; +import edu.umd.dawn.common.exceptions.DawnException; +import org.apache.commons.lang3.builder.EqualsBuilder; +import org.apache.commons.lang3.builder.HashCodeBuilder; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Sort; + +public class OffsetLimitPageRequest implements Pageable { + + private int limit; + private long offset; + private final Sort sort; + + /** + * Creates a new {@link OffsetLimitPageRequest} with sort parameters applied. + * + * @param offset zero-based offset. + * @param limit the size of the elements to be returned. + * @param sort can be {@literal null}. + */ + public OffsetLimitPageRequest(long offset, int limit, Sort sort) { + if (offset < 0) { + throw new DawnException(BaseExceptions.INVALID_OFFSET((int) offset)); + } + + if (limit < 1) { + throw new DawnException(BaseExceptions.INVALID_LIMIT(limit)); + } + this.limit = limit; + this.offset = offset; + this.sort = sort; + } + + /** + * Creates a new {@link OffsetLimitPageRequest} with sort parameters applied. + * + * @param offset zero-based offset. + * @param limit the size of the elements to be returned. + */ + public OffsetLimitPageRequest(long offset, int limit) { + this(offset, limit, Sort.unsorted()); + } + + public static OffsetLimitPageRequest of(long offset, int limit) { + return new OffsetLimitPageRequest(offset, limit); + } + + public static OffsetLimitPageRequest of(long offset, int limit, Sort sort) { + return new OffsetLimitPageRequest(offset, limit, sort); + } + + @Override + public int getPageNumber() { + return (int) offset / limit; + } + + @Override + public int getPageSize() { + return limit; + } + + @Override + public long getOffset() { + return offset; + } + + @Override + public Sort getSort() { + return sort; + } + + @Override + public Pageable next() { + return new OffsetLimitPageRequest(getOffset() + getPageSize(), getPageSize(), getSort()); + } + + public OffsetLimitPageRequest previous() { + return hasPrevious() ? new OffsetLimitPageRequest(getOffset() - getPageSize(), getPageSize(), getSort()) : this; + } + + @Override + public Pageable previousOrFirst() { + return hasPrevious() ? previous() : first(); + } + + @Override + public Pageable first() { + return new OffsetLimitPageRequest(0, getPageSize(), getSort()); + } + + @Override + public boolean hasPrevious() { + return offset > limit; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + + if (!(o instanceof OffsetLimitPageRequest)) return false; + + OffsetLimitPageRequest that = (OffsetLimitPageRequest) o; + + return new EqualsBuilder() + .append(limit, that.limit) + .append(offset, that.offset) + .append(sort, that.sort) + .isEquals(); + } + + @Override + public int hashCode() { + return new HashCodeBuilder(17, 37) + .append(limit) + .append(offset) + .append(sort) + .toHashCode(); + } + + @Override + public String toString() { + return new ToStringBuilder(this) + .append("limit", limit) + .append("offset", offset) + .append("sort", sort) + .toString(); + } + + @Override + public Pageable withPage(int pageNumber) { + return null; + } +}