[F] Fix zlib compression happening after response commit

pull/29/head
Azalea 2024-03-28 00:58:55 -04:00
parent 3f01152a4a
commit c6190146aa
9 changed files with 64 additions and 200 deletions

View File

@ -3,7 +3,7 @@ package icu.samnyan.aqua.sega.allnet
import ext.*
import icu.samnyan.aqua.net.db.AquaNetUserRepo
import icu.samnyan.aqua.sega.util.AquaConst
import icu.samnyan.aqua.sega.util.Decoder.decodeAllNet
import icu.samnyan.aqua.sega.util.AllNetBillingDecoder.decodeAllNet
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.slf4j.Logger

View File

@ -1,7 +1,7 @@
package icu.samnyan.aqua.sega.billing
import ext.toUrl
import icu.samnyan.aqua.sega.util.Decoder.decodeBilling
import icu.samnyan.aqua.sega.util.AllNetBillingDecoder.decodeBilling
import jakarta.annotation.PostConstruct
import jakarta.servlet.http.HttpServletRequest
import org.eclipse.jetty.http.HttpVersion

View File

@ -1,68 +0,0 @@
package icu.samnyan.aqua.sega.diva.filter
import icu.samnyan.aqua.sega.general.filter.CompressRequestWrapper
import icu.samnyan.aqua.sega.general.filter.CompressResponseWrapper
import icu.samnyan.aqua.sega.util.Compression
import jakarta.servlet.FilterChain
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.eclipse.jetty.io.EofException
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.springframework.stereotype.Component
import org.springframework.web.filter.OncePerRequestFilter
import java.util.*
/**
* @author samnyan (privateamusement@protonmail.com)
*/
@Component
class DivaCompressionFilter : OncePerRequestFilter() {
companion object {
val log: Logger = LoggerFactory.getLogger(DivaCompressionFilter::class.java)
}
override fun doFilterInternal(req: HttpServletRequest, resp: HttpServletResponse, chain: FilterChain) {
log.debug(">>> DIVA Incoming request: ${req.servletPath}")
log.debug("> ${req.headerNames.toList().map { it to req.getHeader(it) }}")
val encoding = req.getHeader("pragma")
val reqSrc = req.inputStream.readAllBytes()
log.debug("> Encoding: $encoding")
var reqResult: ByteArray?
if (encoding != null && encoding == "DFI") {
log.debug("> Request length (compressed): ${reqSrc.size}")
reqResult = Base64.getMimeDecoder().decode(reqSrc)
reqResult = Compression.decompress(reqResult)
log.debug("> Request length (decompressed): ${reqResult.size}")
} else {
reqResult = reqSrc
}
val requestWrapper = CompressRequestWrapper(req, reqResult)
val responseWrapper = CompressResponseWrapper(resp)
chain.doFilter(requestWrapper, responseWrapper)
val respSrc = responseWrapper.toByteArray()
log.debug(">>> DIVA Outgoing response: $respSrc")
log.debug("> Response length (uncompressed): ${respSrc.size}")
var respResult = Compression.compress(respSrc)
log.debug("> Response length (compressed): ${respResult.size}")
respResult = Base64.getMimeEncoder().encode(respResult)
resp.setContentLength(respResult.size)
resp.setHeader("pragma", "DFI")
try {
resp.outputStream.write(respResult)
} catch (e: EofException) {
log.warn("- EOF: Client closed connection when writing result :(")
}
}
override fun shouldNotFilter(request: HttpServletRequest): Boolean {
return !request.servletPath.startsWith("/g/diva")
}
}

View File

@ -1,50 +0,0 @@
package icu.samnyan.aqua.sega.general.filter;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.WriteListener;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpServletResponseWrapper;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
/**
* @author samnyan (privateamusement@protonmail.com)
*/
public class CompressResponseWrapper extends HttpServletResponseWrapper {
private final ByteArrayOutputStream output;
private ServletOutputStream filterOutput;
public CompressResponseWrapper(HttpServletResponse response) {
super(response);
output = new ByteArrayOutputStream();
}
@Override
public ServletOutputStream getOutputStream() {
if (filterOutput == null) {
filterOutput = new ServletOutputStream() {
@Override
public boolean isReady() {
return false;
}
@Override
public void setWriteListener(WriteListener writeListener) {
}
@Override
public void write(int b) {
output.write(b);
}
};
}
return filterOutput;
}
public byte[] toByteArray() {
return output.toByteArray();
}
}

View File

@ -1,13 +1,16 @@
package icu.samnyan.aqua.sega.general.filter
import ext.logger
import icu.samnyan.aqua.sega.util.Compression
import icu.samnyan.aqua.sega.util.ZLib
import jakarta.servlet.FilterChain
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.eclipse.jetty.io.EofException
import org.springframework.stereotype.Component
import org.springframework.web.filter.OncePerRequestFilter
import org.springframework.web.util.ContentCachingResponseWrapper
import java.util.*
/**
* @author samnyan (privateamusement@protonmail.com)
@ -16,27 +19,37 @@ import org.springframework.web.filter.OncePerRequestFilter
class CompressionFilter : OncePerRequestFilter() {
companion object {
val logger = logger()
val b64d = Base64.getMimeDecoder()
val b64e = Base64.getMimeEncoder()
}
override fun doFilterInternal(req: HttpServletRequest, resp: HttpServletResponse, chain: FilterChain) {
val isDeflate = req.getHeader("content-encoding") == "deflate"
val isDfi = req.getHeader("pragma") == "DFI"
// Decode input
val reqSrc = req.inputStream.readAllBytes().let {
if (req.getHeader("content-encoding") == "deflate") Compression.decompress(it)
if (isDeflate) ZLib.decompress(it)
else if (isDfi) ZLib.decompress(b64d.decode(it))
else it
}
val requestWrapper = CompressRequestWrapper(req, reqSrc)
val responseWrapper = CompressResponseWrapper(resp)
chain.doFilter(requestWrapper, responseWrapper)
val result = Compression.compress(responseWrapper.toByteArray())
// Handle request
val result = ContentCachingResponseWrapper(resp).run {
chain.doFilter(CompressRequestWrapper(req, reqSrc), this)
ZLib.compress(contentAsByteArray).let { if (isDfi) b64e.encode(it) else it }
}
// Write response
resp.setContentLength(result.size)
resp.contentType = "application/json; charset=utf-8"
resp.addHeader("Content-Encoding", "deflate")
if (isDfi) resp.setHeader("pragma", "DFI")
if (isDeflate) {
resp.contentType = "application/json; charset=utf-8"
resp.setHeader("content-encoding", "deflate")
}
try {
resp.outputStream.write(result)
resp.outputStream.use { it.write(result); it.flush() }
} catch (e: EofException) {
logger.warn("- EOF: Client closed connection when writing result")
}
@ -46,6 +59,5 @@ class CompressionFilter : OncePerRequestFilter() {
* Filter games that are not diva
*/
override fun shouldNotFilter(req: HttpServletRequest) =
!(req.servletPath.startsWith("/g/") && !req.servletPath.startsWith("/g/diva")
&& !req.servletPath.startsWith("/g/wacca"))
!(req.servletPath.startsWith("/g/") && !req.servletPath.startsWith("/g/wacca"))
}

View File

@ -277,8 +277,8 @@ class Maimai2ServletController(
"""{"returnCode":1,"apiName":"com.sega.maimai2servlet.api.$api"}"""
}
} catch (e: ApiException) {
logger.warn("Mai2 > $api : ${e.code} - ${e.message}")
return ResponseEntity.status(e.code).body("""{"returnCode":0,"apiName":"com.sega.maimai2servlet.api.$api","message":"${e.message?.replace("\"", "\\\"")} - ${e.code}"}""")
// It's a bad practice to return 200 ok on error, but this is what maimai does so we have to follow
return ResponseEntity.ok().body("""{"returnCode":0,"apiName":"com.sega.maimai2servlet.api.$api","message":"${e.message?.replace("\"", "\\\"")} - ${e.code}"}""")
}
}
}

View File

@ -3,17 +3,16 @@ package icu.samnyan.aqua.sega.util
import java.util.*
import kotlin.text.Charsets.UTF_8
object Decoder {
object AllNetBillingDecoder {
/**
* Decode the input byte array from Base64 MIME encoding and decompress the decoded byte array
*/
fun decode(src: ByteArray, base64: Boolean, nowrap: Boolean): Map<String, String> {
// Decode the input byte array from Base64 MIME encoding
var bytes = src
if (base64) bytes = Base64.getMimeDecoder().decode(bytes)
val bytes = if (base64) src else Base64.getMimeDecoder().decode(src)
// Decompress the decoded byte array
val output = Compression.decompress(bytes, nowrap).toString(UTF_8).trim()
val output = ZLib.decompress(bytes, nowrap).toString(UTF_8).trim()
// Split the string by '&' symbol to separate key-value pairs
return output.split("&").associate {

View File

@ -1,61 +0,0 @@
package icu.samnyan.aqua.sega.util;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
/**
* @author samnyan (privateamusement@protonmail.com)
*/
public class Compression {
public static byte[] decompress(byte[] src, boolean nowrap) {
ByteBuf result = Unpooled.buffer();
byte[] buffer = new byte[100];
Inflater decompressor = new Inflater(nowrap);
decompressor.setInput(src);
try {
while (!decompressor.finished()) {
int count = decompressor.inflate(buffer);
if (count == 0) {
break;
}
result.writeBytes(buffer, result.readerIndex(), count);
}
decompressor.end();
return ByteBufUtil.toBytes(result);
} catch (DataFormatException e) {
e.printStackTrace();
return new byte[0];
}
}
public static byte[] decompress(byte[] src) {
return decompress(src, false);
}
public static byte[] compress(byte[] src) {
ByteBuf result = Unpooled.buffer();
byte[] buffer = new byte[100];
Deflater compressor = new Deflater();
compressor.setInput(src);
compressor.finish();
while (!compressor.finished()) {
int count = compressor.deflate(buffer);
if (count == 0) {
break;
}
result.writeBytes(buffer, result.readerIndex(), count);
}
compressor.end();
return ByteBufUtil.toBytes(result);
}
}

View File

@ -0,0 +1,32 @@
package icu.samnyan.aqua.sega.util
import java.io.ByteArrayOutputStream
import java.util.zip.Deflater
import java.util.zip.Inflater
object ZLib {
fun decompress(src: ByteArray, nowrap: Boolean = false) = Inflater(nowrap).run {
val buffer = ByteArray(1024)
setInput(src)
ByteArrayOutputStream().use {
var count = -1
while (count != 0) {
count = inflate(buffer)
it.write(buffer, 0, count)
}
end()
it.toByteArray()
}
}
fun compress(src: ByteArray) = Deflater().run {
setInput(src)
finish()
val outputBuf = ByteArray(src.size * 4)
val compressedSize = deflate(outputBuf)
end()
outputBuf.copyOf(compressedSize)
}
}