data handler improves, written test for nfproxy, new option on parsing fail
This commit is contained in:
@@ -94,7 +94,6 @@ This handler will be called twice: one for the request headers and one for the r
|
||||
- headers: dict - The headers of the request
|
||||
- user_agent: str - The user agent of the request
|
||||
- content_encoding: str - The content encoding of the request
|
||||
- has_begun: bool - It's true if the request has begun
|
||||
- body: bytes - The body of the request
|
||||
- headers_complete: bool - It's true if the headers are complete
|
||||
- message_complete: bool - It's true if the message is complete
|
||||
@@ -122,7 +121,6 @@ This handler will be called twice: one for the response headers and one for the
|
||||
- headers: dict - The headers of the response
|
||||
- user_agent: str - The user agent of the response
|
||||
- content_encoding: str - The content encoding of the response
|
||||
- has_begun: bool - It's true if the response has begun
|
||||
- body: bytes - The body of the response
|
||||
- headers_complete: bool - It's true if the headers are complete
|
||||
- message_complete: bool - It's true if the message is complete
|
||||
|
||||
@@ -75,12 +75,9 @@ def handle_packet(glob: dict) -> None:
|
||||
|
||||
cache_call = {} # Cache of the data handler calls
|
||||
cache_call[RawPacket] = internal_data.current_pkt
|
||||
|
||||
final_result = Action.ACCEPT
|
||||
|
||||
result = PacketHandlerResult(glob)
|
||||
|
||||
func_name = None
|
||||
mangled_packet = None
|
||||
for filter in internal_data.filter_call_info:
|
||||
final_params = []
|
||||
skip_call = False
|
||||
@@ -116,24 +113,37 @@ def handle_packet(glob: dict) -> None:
|
||||
if skip_call:
|
||||
continue
|
||||
|
||||
res = context_call(glob, filter.func, *final_params)
|
||||
|
||||
if res is None:
|
||||
continue #ACCEPTED
|
||||
if not isinstance(res, Action):
|
||||
raise Exception(f"Invalid return type {type(res)} for function {filter.name}")
|
||||
if res == Action.MANGLE:
|
||||
mangled_packet = internal_data.current_pkt.raw_packet
|
||||
if res != Action.ACCEPT:
|
||||
func_name = filter.name
|
||||
final_result = res
|
||||
break
|
||||
# Create an iterator with all the calls to be done
|
||||
def try_to_call(params:list):
|
||||
is_base_call = True
|
||||
for i in range(len(params)):
|
||||
if isinstance(params[i], list):
|
||||
new_params = params.copy()
|
||||
for ele in params[i]:
|
||||
new_params[i] = ele
|
||||
for ele in try_to_call(new_params):
|
||||
yield ele
|
||||
is_base_call = False
|
||||
break
|
||||
if is_base_call:
|
||||
yield context_call(glob, filter.func, *params)
|
||||
|
||||
for res in try_to_call(final_params):
|
||||
if res is None:
|
||||
continue #ACCEPTED
|
||||
if not isinstance(res, Action):
|
||||
raise Exception(f"Invalid return type {type(res)} for function {filter.name}")
|
||||
if res == Action.MANGLE:
|
||||
result.matched_by = filter.name
|
||||
result.mangled_packet = internal_data.current_pkt.raw_packet
|
||||
result.action = Action.MANGLE
|
||||
elif res != Action.ACCEPT:
|
||||
result.matched_by = filter.name
|
||||
result.action = res
|
||||
result.mangled_packet = None
|
||||
return result.set_result()
|
||||
|
||||
result.action = final_result
|
||||
result.matched_by = func_name
|
||||
result.mangled_packet = mangled_packet
|
||||
|
||||
return result.set_result()
|
||||
return result.set_result() # Will be MANGLE or ACCEPT
|
||||
|
||||
|
||||
def compile(glob:dict) -> None:
|
||||
@@ -148,13 +158,12 @@ def compile(glob:dict) -> None:
|
||||
|
||||
if "FGEX_STREAM_MAX_SIZE" in glob and int(glob["FGEX_STREAM_MAX_SIZE"]) > 0:
|
||||
internal_data.stream_max_size = int(glob["FGEX_STREAM_MAX_SIZE"])
|
||||
else:
|
||||
internal_data.stream_max_size = 1*8e20 # 1MB default value
|
||||
|
||||
if "FGEX_FULL_STREAM_ACTION" in glob and isinstance(glob["FGEX_FULL_STREAM_ACTION"], FullStreamAction):
|
||||
internal_data.full_stream_action = glob["FGEX_FULL_STREAM_ACTION"]
|
||||
else:
|
||||
internal_data.full_stream_action = FullStreamAction.FLUSH
|
||||
|
||||
if "FGEX_INVALID_ENCODING_ACTION" in glob and isinstance(glob["FGEX_INVALID_ENCODING_ACTION"], Action):
|
||||
internal_data.invalid_encoding_action = glob["FGEX_INVALID_ENCODING_ACTION"]
|
||||
|
||||
PacketHandlerResult(glob).reset_result()
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from firegex.nfproxy.internals.models import FilterHandler
|
||||
from firegex.nfproxy.internals.models import FullStreamAction
|
||||
from firegex.nfproxy.internals.models import FullStreamAction, ExceptionAction
|
||||
|
||||
class RawPacket:
|
||||
"class rapresentation of the nfqueue packet sent in python context by the c++ core"
|
||||
@@ -120,23 +120,39 @@ class DataStreamCtx:
|
||||
@property
|
||||
def stream_max_size(self) -> int:
|
||||
if "stream_max_size" not in self.__data.keys():
|
||||
self.__data["stream_max_size"] = 1*8e20
|
||||
self.__data["stream_max_size"] = 1*8e20 # 1MB default value
|
||||
return self.__data.get("stream_max_size")
|
||||
|
||||
@stream_max_size.setter
|
||||
def stream_max_size(self, v: int):
|
||||
if not isinstance(v, int):
|
||||
raise Exception("Invalid data type, data MUST be of type int")
|
||||
self.__data["stream_max_size"] = v
|
||||
|
||||
@property
|
||||
def full_stream_action(self) -> FullStreamAction:
|
||||
if "full_stream_action" not in self.__data.keys():
|
||||
self.__data["full_stream_action"] = "flush"
|
||||
self.__data["full_stream_action"] = FullStreamAction.FLUSH
|
||||
return self.__data.get("full_stream_action")
|
||||
|
||||
@full_stream_action.setter
|
||||
def full_stream_action(self, v: FullStreamAction):
|
||||
if not isinstance(v, FullStreamAction):
|
||||
raise Exception("Invalid data type, data MUST be of type FullStreamAction")
|
||||
self.__data["full_stream_action"] = v
|
||||
|
||||
@property
|
||||
def invalid_encoding_action(self) -> ExceptionAction:
|
||||
if "invalid_encoding_action" not in self.__data.keys():
|
||||
self.__data["invalid_encoding_action"] = ExceptionAction.REJECT
|
||||
return self.__data.get("invalid_encoding_action")
|
||||
|
||||
@invalid_encoding_action.setter
|
||||
def invalid_encoding_action(self, v: ExceptionAction):
|
||||
if not isinstance(v, ExceptionAction):
|
||||
raise Exception("Invalid data type, data MUST be of type ExceptionAction")
|
||||
self.__data["invalid_encoding_action"] = v
|
||||
|
||||
@property
|
||||
def data_handler_context(self) -> dict:
|
||||
if "data_handler_context" not in self.__data.keys():
|
||||
|
||||
@@ -13,3 +13,4 @@ class RejectConnection(Exception):
|
||||
|
||||
class StreamFullReject(Exception):
|
||||
"raise this exception if you want to reject the connection due to full stream"
|
||||
|
||||
|
||||
@@ -8,6 +8,13 @@ class Action(Enum):
|
||||
REJECT = 2
|
||||
MANGLE = 3
|
||||
|
||||
class ExceptionAction(Enum):
|
||||
"""Action to be taken by the filter when an exception occurs (used in some cases)"""
|
||||
ACCEPT = 0
|
||||
DROP = 1
|
||||
REJECT = 2
|
||||
NOACTION = 3
|
||||
|
||||
class FullStreamAction(Enum):
|
||||
"""Action to be taken by the filter when the stream is full"""
|
||||
FLUSH = 0
|
||||
@@ -40,5 +47,3 @@ class PacketHandlerResult:
|
||||
|
||||
def reset_result(self) -> None:
|
||||
self.glob["__firegex_pyfilter_result"] = None
|
||||
|
||||
|
||||
|
||||
@@ -1,101 +1,143 @@
|
||||
import pyllhttp
|
||||
from firegex.nfproxy.internals.exceptions import NotReadyToRun
|
||||
from firegex.nfproxy.internals.data import DataStreamCtx
|
||||
from firegex.nfproxy.internals.exceptions import StreamFullDrop, StreamFullReject
|
||||
from firegex.nfproxy.internals.models import FullStreamAction
|
||||
from firegex.nfproxy.internals.exceptions import StreamFullDrop, StreamFullReject, RejectConnection, DropPacket
|
||||
from firegex.nfproxy.internals.models import FullStreamAction, ExceptionAction
|
||||
from dataclasses import dataclass, field
|
||||
from collections import deque
|
||||
from typing import Type
|
||||
|
||||
@dataclass
|
||||
class InternalHTTPMessage:
|
||||
"""Internal class to handle HTTP messages"""
|
||||
url: str|None = field(default=None)
|
||||
headers: dict[str, str] = field(default_factory=dict)
|
||||
lheaders: dict[str, str] = field(default_factory=dict) # lowercase copy of the headers
|
||||
body: bytes|None = field(default=None)
|
||||
headers_complete: bool = field(default=False)
|
||||
message_complete: bool = field(default=False)
|
||||
status: str|None = field(default=None)
|
||||
total_size: int = field(default=0)
|
||||
user_agent: str = field(default_factory=str)
|
||||
content_encoding: str = field(default=str)
|
||||
content_type: str = field(default=str)
|
||||
keep_alive: bool = field(default=False)
|
||||
should_upgrade: bool = field(default=False)
|
||||
http_version: str = field(default=str)
|
||||
method: str = field(default=str)
|
||||
content_length: int = field(default=0)
|
||||
stream: bytes = field(default_factory=bytes)
|
||||
|
||||
@dataclass
|
||||
class InternalHttpBuffer:
|
||||
"""Internal class to handle HTTP messages"""
|
||||
_url_buffer: bytes = field(default_factory=bytes)
|
||||
_header_fields: dict[bytes, bytes] = field(default_factory=dict)
|
||||
_body_buffer: bytes = field(default_factory=bytes)
|
||||
_status_buffer: bytes = field(default_factory=bytes)
|
||||
_current_header_field: bytes = field(default_factory=bytes)
|
||||
_current_header_value: bytes = field(default_factory=bytes)
|
||||
|
||||
class InternalCallbackHandler():
|
||||
|
||||
url: str|None = None
|
||||
_url_buffer: bytes = b""
|
||||
headers: dict[str, str] = {}
|
||||
lheaders: dict[str, str] = {} # Lowercase headers
|
||||
_header_fields: dict[bytes, bytes] = {}
|
||||
has_begun: bool = False
|
||||
body: bytes = None
|
||||
_body_buffer: bytes = b""
|
||||
headers_complete: bool = False
|
||||
message_complete: bool = False
|
||||
status: str|None = None
|
||||
_status_buffer: bytes = b""
|
||||
_current_header_field = b""
|
||||
_current_header_value = b""
|
||||
_save_body = True
|
||||
total_size = 0
|
||||
buffers = InternalHttpBuffer()
|
||||
msg = InternalHTTPMessage()
|
||||
save_body = True
|
||||
raised_error = False
|
||||
has_begun = False
|
||||
messages: deque[InternalHTTPMessage] = deque()
|
||||
|
||||
def reset_data(self):
|
||||
self.msg = InternalHTTPMessage()
|
||||
self.buffers = InternalHttpBuffer()
|
||||
self.messages.clear()
|
||||
|
||||
def on_message_begin(self):
|
||||
self.buffers = InternalHttpBuffer()
|
||||
self.msg = InternalHTTPMessage()
|
||||
self.has_begun = True
|
||||
|
||||
def on_url(self, url):
|
||||
self.total_size += len(url)
|
||||
self._url_buffer += url
|
||||
self.buffers._url_buffer += url
|
||||
self.msg.total_size += len(url)
|
||||
|
||||
def on_url_complete(self):
|
||||
self.url = self._url_buffer.decode(errors="ignore")
|
||||
self._url_buffer = None
|
||||
self.msg.url = self.buffers._url_buffer.decode(errors="ignore")
|
||||
self.buffers._url_buffer = b""
|
||||
|
||||
def on_status(self, status: bytes):
|
||||
self.msg.total_size += len(status)
|
||||
self.buffers._status_buffer += status
|
||||
|
||||
def on_status_complete(self):
|
||||
self.msg.status = self.buffers._status_buffer.decode(errors="ignore")
|
||||
self.buffers._status_buffer = b""
|
||||
|
||||
def on_header_field(self, field):
|
||||
self.total_size += len(field)
|
||||
self._current_header_field += field
|
||||
self.msg.total_size += len(field)
|
||||
self.buffers._current_header_field += field
|
||||
|
||||
def on_header_field_complete(self):
|
||||
self._current_header_field = self._current_header_field
|
||||
pass # Nothing to do
|
||||
|
||||
def on_header_value(self, value):
|
||||
self.total_size += len(value)
|
||||
self._current_header_value += value
|
||||
self.msg.total_size += len(value)
|
||||
self.buffers._current_header_value += value
|
||||
|
||||
def on_header_value_complete(self):
|
||||
if self._current_header_value is not None and self._current_header_field is not None:
|
||||
self._header_fields[self._current_header_field.decode(errors="ignore")] = self._current_header_value.decode(errors="ignore")
|
||||
self._current_header_field = b""
|
||||
self._current_header_value = b""
|
||||
if self.buffers._current_header_field:
|
||||
self.buffers._header_fields[self.buffers._current_header_field.decode(errors="ignore")] = self.buffers._current_header_value.decode(errors="ignore")
|
||||
self.buffers._current_header_field = b""
|
||||
self.buffers._current_header_value = b""
|
||||
|
||||
def on_headers_complete(self):
|
||||
self.headers_complete = True
|
||||
self.headers = self._header_fields
|
||||
self.lheaders = {k.lower(): v for k, v in self._header_fields.items()}
|
||||
self._header_fields = {}
|
||||
self._current_header_field = b""
|
||||
self._current_header_value = b""
|
||||
self.msg.headers = self.buffers._header_fields
|
||||
self.msg.lheaders = {k.lower(): v for k, v in self.buffers._header_fields.items()}
|
||||
self.buffers._header_fields = {}
|
||||
self.buffers._current_header_field = b""
|
||||
self.buffers._current_header_value = b""
|
||||
self.msg.headers_complete = True
|
||||
self.msg.method = self.method_parsed
|
||||
self.msg.content_length = self.content_length_parsed
|
||||
self.msg.should_upgrade = self.should_upgrade
|
||||
self.msg.keep_alive = self.keep_alive
|
||||
self.msg.http_version = self.http_version
|
||||
self.msg.content_type = self.content_type
|
||||
self.msg.content_encoding = self.content_encoding
|
||||
self.msg.user_agent = self.user_agent
|
||||
|
||||
def on_body(self, body: bytes):
|
||||
if self._save_body:
|
||||
self.total_size += len(body)
|
||||
self._body_buffer += body
|
||||
if self.save_body:
|
||||
self.msg.total_size += len(body)
|
||||
self.buffers._body_buffer += body
|
||||
|
||||
def on_message_complete(self):
|
||||
self.body = self._body_buffer
|
||||
self._body_buffer = b""
|
||||
self.msg.body = self.buffers._body_buffer
|
||||
self.buffers._body_buffer = b""
|
||||
try:
|
||||
if "gzip" in self.content_encoding.lower():
|
||||
import gzip
|
||||
import io
|
||||
with gzip.GzipFile(fileobj=io.BytesIO(self.body)) as f:
|
||||
self.body = f.read()
|
||||
with gzip.GzipFile(fileobj=io.BytesIO(self.msg.body)) as f:
|
||||
self.msg.body = f.read()
|
||||
except Exception as e:
|
||||
print(f"Error decompressing gzip: {e}: skipping", flush=True)
|
||||
self.message_complete = True
|
||||
|
||||
def on_status(self, status: bytes):
|
||||
self.total_size += len(status)
|
||||
self._status_buffer += status
|
||||
|
||||
def on_status_complete(self):
|
||||
self.status = self._status_buffer.decode(errors="ignore")
|
||||
self._status_buffer = b""
|
||||
self.msg.message_complete = True
|
||||
self.has_begun = False
|
||||
if not self._packet_to_stream():
|
||||
self.messages.append(self.msg)
|
||||
|
||||
@property
|
||||
def user_agent(self) -> str:
|
||||
return self.lheaders.get("user-agent", "")
|
||||
return self.msg.lheaders.get("user-agent", "")
|
||||
|
||||
@property
|
||||
def content_encoding(self) -> str:
|
||||
return self.lheaders.get("content-encoding", "")
|
||||
return self.msg.lheaders.get("content-encoding", "")
|
||||
|
||||
@property
|
||||
def content_type(self) -> str:
|
||||
return self.lheaders.get("content-type", "")
|
||||
return self.msg.lheaders.get("content-type", "")
|
||||
|
||||
@property
|
||||
def keep_alive(self) -> bool:
|
||||
@@ -107,16 +149,49 @@ class InternalCallbackHandler():
|
||||
|
||||
@property
|
||||
def http_version(self) -> str:
|
||||
return f"{self.major}.{self.minor}"
|
||||
if self.major and self.minor:
|
||||
return f"{self.major}.{self.minor}"
|
||||
else:
|
||||
return ""
|
||||
|
||||
@property
|
||||
def method_parsed(self) -> str:
|
||||
return self.method.decode(errors="ignore")
|
||||
return self.method
|
||||
|
||||
@property
|
||||
def total_size(self) -> int:
|
||||
"""Total size used by the parser"""
|
||||
tot = self.msg.total_size
|
||||
for msg in self.messages:
|
||||
tot += msg.total_size
|
||||
return tot
|
||||
|
||||
@property
|
||||
def content_length_parsed(self) -> int:
|
||||
return self.content_length
|
||||
|
||||
def _packet_to_stream(self):
|
||||
return self.should_upgrade and self.save_body
|
||||
|
||||
def parse_data(self, data: bytes):
|
||||
if self._packet_to_stream(): # This is a websocket upgrade!
|
||||
self.msg.message_complete = True # The message is complete but becomed a stream, so need to be called every time a new packet is received
|
||||
self.msg.total_size += len(data)
|
||||
self.msg.stream += data #buffering stream
|
||||
else:
|
||||
try:
|
||||
self.execute(data)
|
||||
except Exception as e:
|
||||
self.raised_error = True
|
||||
print(f"Error parsing HTTP packet: {e} with data {data}", flush=True)
|
||||
raise e
|
||||
|
||||
def pop_message(self):
|
||||
return self.messages.popleft()
|
||||
|
||||
def __repr__(self):
|
||||
return f"<InternalCallbackHandler msg={self.msg} buffers={self.buffers} save_body={self.save_body} raised_error={self.raised_error} has_begun={self.has_begun} messages={self.messages}>"
|
||||
|
||||
|
||||
class InternalHttpRequest(InternalCallbackHandler, pyllhttp.Request):
|
||||
def __init__(self):
|
||||
@@ -131,11 +206,15 @@ class InternalHttpResponse(InternalCallbackHandler, pyllhttp.Response):
|
||||
class InternalBasicHttpMetaClass:
|
||||
"""Internal class to handle HTTP requests and responses"""
|
||||
|
||||
def __init__(self):
|
||||
self._parser: InternalHttpRequest|InternalHttpResponse
|
||||
self._headers_were_set = False
|
||||
def __init__(self, parser: InternalHttpRequest|InternalHttpResponse, msg: InternalHTTPMessage):
|
||||
self._parser = parser
|
||||
self.stream = b""
|
||||
self.raised_error = False
|
||||
self._message: InternalHTTPMessage|None = msg
|
||||
self._contructor_hook()
|
||||
|
||||
def _contructor_hook(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def total_size(self) -> int:
|
||||
@@ -145,116 +224,98 @@ class InternalBasicHttpMetaClass:
|
||||
@property
|
||||
def url(self) -> str|None:
|
||||
"""URL of the message"""
|
||||
return self._parser.url
|
||||
return self._message.url
|
||||
|
||||
@property
|
||||
def headers(self) -> dict[str, str]:
|
||||
"""Headers of the message"""
|
||||
return self._parser.headers
|
||||
return self._message.headers
|
||||
|
||||
@property
|
||||
def user_agent(self) -> str:
|
||||
"""User agent of the message"""
|
||||
return self._parser.user_agent
|
||||
return self._message.user_agent
|
||||
|
||||
@property
|
||||
def content_encoding(self) -> str:
|
||||
"""Content encoding of the message"""
|
||||
return self._parser.content_encoding
|
||||
|
||||
@property
|
||||
def has_begun(self) -> bool:
|
||||
"""If the message has begun"""
|
||||
return self._parser.has_begun
|
||||
return self._message.content_encoding
|
||||
|
||||
@property
|
||||
def body(self) -> bytes:
|
||||
"""Body of the message"""
|
||||
return self._parser.body
|
||||
return self._message.body
|
||||
|
||||
@property
|
||||
def headers_complete(self) -> bool:
|
||||
"""If the headers are complete"""
|
||||
return self._parser.headers_complete
|
||||
return self._message.headers_complete
|
||||
|
||||
@property
|
||||
def message_complete(self) -> bool:
|
||||
"""If the message is complete"""
|
||||
return self._parser.message_complete
|
||||
return self._message.message_complete
|
||||
|
||||
@property
|
||||
def http_version(self) -> str:
|
||||
"""HTTP version of the message"""
|
||||
return self._parser.http_version
|
||||
return self._message.http_version
|
||||
|
||||
@property
|
||||
def keep_alive(self) -> bool:
|
||||
"""If the message should keep alive"""
|
||||
return self._parser.keep_alive
|
||||
return self._message.keep_alive
|
||||
|
||||
@property
|
||||
def should_upgrade(self) -> bool:
|
||||
"""If the message should upgrade"""
|
||||
return self._parser.should_upgrade
|
||||
return self._message.should_upgrade
|
||||
|
||||
@property
|
||||
def content_length(self) -> int|None:
|
||||
"""Content length of the message"""
|
||||
return self._parser.content_length_parsed
|
||||
return self._message.content_length
|
||||
|
||||
def get_header(self, header: str, default=None) -> str:
|
||||
"""Get a header from the message without caring about the case"""
|
||||
return self._parser.lheaders.get(header.lower(), default)
|
||||
return self._message.lheaders.get(header.lower(), default)
|
||||
|
||||
def _packet_to_stream(self, internal_data: DataStreamCtx):
|
||||
return self.should_upgrade and self._parser._save_body
|
||||
@staticmethod
|
||||
def _associated_parser_class() -> Type[InternalHttpRequest]|Type[InternalHttpResponse]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _fetch_current_packet(self, internal_data: DataStreamCtx):
|
||||
if self._packet_to_stream(internal_data): # This is a websocket upgrade!
|
||||
self._parser.total_size += len(internal_data.current_pkt.data)
|
||||
self.stream += internal_data.current_pkt.data
|
||||
else:
|
||||
try:
|
||||
self._parser.execute(internal_data.current_pkt.data)
|
||||
if not self._parser.message_complete and self._parser.headers_complete and len(self._parser._body_buffer) == self._parser.content_length_parsed:
|
||||
self._parser.on_message_complete()
|
||||
except Exception as e:
|
||||
self.raised_error = True
|
||||
print(f"Error parsing HTTP packet: {e} {internal_data.current_pkt}", self, flush=True)
|
||||
raise e
|
||||
|
||||
#It's called the first time if the headers are complete, and second time with body complete
|
||||
def _after_fetch_callable_checks(self, internal_data: DataStreamCtx):
|
||||
if self._parser.headers_complete and not self._headers_were_set:
|
||||
self._headers_were_set = True
|
||||
return True
|
||||
return self._parser.message_complete or self.should_upgrade
|
||||
|
||||
def _before_fetch_callable_checks(self, internal_data: DataStreamCtx):
|
||||
@staticmethod
|
||||
def _before_fetch_callable_checks(internal_data: DataStreamCtx):
|
||||
return True
|
||||
|
||||
def _trigger_remove_data(self, internal_data: DataStreamCtx):
|
||||
return self.message_complete and not self.should_upgrade
|
||||
|
||||
@classmethod
|
||||
def _fetch_packet(cls, internal_data: DataStreamCtx):
|
||||
if internal_data.current_pkt is None or internal_data.current_pkt.is_tcp is False:
|
||||
raise NotReadyToRun()
|
||||
|
||||
datahandler:InternalBasicHttpMetaClass = internal_data.data_handler_context.get(cls, None)
|
||||
if datahandler is None or datahandler.raised_error:
|
||||
datahandler = cls()
|
||||
internal_data.data_handler_context[cls] = datahandler
|
||||
ParserType = cls._associated_parser_class()
|
||||
|
||||
if not datahandler._before_fetch_callable_checks(internal_data):
|
||||
parser = internal_data.data_handler_context.get(cls, None)
|
||||
if parser is None or parser.raised_error:
|
||||
parser: InternalHttpRequest|InternalHttpResponse = ParserType()
|
||||
internal_data.data_handler_context[cls] = parser
|
||||
|
||||
if not cls._before_fetch_callable_checks(internal_data):
|
||||
raise NotReadyToRun()
|
||||
|
||||
# Memory size managment
|
||||
if datahandler.total_size+len(internal_data.current_pkt.data) > internal_data.stream_max_size:
|
||||
if parser.total_size+len(internal_data.current_pkt.data) > internal_data.stream_max_size:
|
||||
match internal_data.full_stream_action:
|
||||
case FullStreamAction.FLUSH:
|
||||
datahandler = cls()
|
||||
internal_data.data_handler_context[cls] = datahandler
|
||||
# Deleting parser and re-creating it
|
||||
parser.messages.clear()
|
||||
parser.msg.total_size -= len(parser.msg.stream)
|
||||
parser.msg.stream = b""
|
||||
parser.msg.total_size -= len(parser.msg.body)
|
||||
parser.msg.body = b""
|
||||
print("[WARNING] Flushing stream", flush=True)
|
||||
if parser.total_size+len(internal_data.current_pkt.data) > internal_data.stream_max_size:
|
||||
parser.reset_data()
|
||||
case FullStreamAction.REJECT:
|
||||
raise StreamFullReject()
|
||||
case FullStreamAction.DROP:
|
||||
@@ -262,16 +323,41 @@ class InternalBasicHttpMetaClass:
|
||||
case FullStreamAction.ACCEPT:
|
||||
raise NotReadyToRun()
|
||||
|
||||
datahandler._fetch_current_packet(internal_data)
|
||||
headers_were_set = parser.msg.headers_complete
|
||||
try:
|
||||
parser.parse_data(internal_data.current_pkt.data)
|
||||
except Exception as e:
|
||||
match internal_data.invalid_encoding_action:
|
||||
case ExceptionAction.REJECT:
|
||||
raise RejectConnection()
|
||||
case ExceptionAction.DROP:
|
||||
raise DropPacket()
|
||||
case ExceptionAction.NOACTION:
|
||||
raise e
|
||||
case ExceptionAction.ACCEPT:
|
||||
raise NotReadyToRun()
|
||||
|
||||
if not datahandler._after_fetch_callable_checks(internal_data):
|
||||
messages_tosend:list[InternalHTTPMessage] = []
|
||||
for i in range(len(parser.messages)):
|
||||
messages_tosend.append(parser.pop_message())
|
||||
|
||||
if len(messages_tosend) > 0:
|
||||
headers_were_set = False # New messages completed so the current message headers were not set in this case
|
||||
|
||||
if not headers_were_set and parser.msg.headers_complete:
|
||||
messages_tosend.append(parser.msg) # Also the current message needs to be sent due to complete headers
|
||||
|
||||
if headers_were_set and parser.msg.message_complete and parser.msg.should_upgrade and parser.save_body:
|
||||
messages_tosend.append(parser.msg) # Also the current message needs to beacase a websocket stream is going on
|
||||
|
||||
messages_to_call = len(messages_tosend)
|
||||
|
||||
if messages_to_call == 0:
|
||||
raise NotReadyToRun()
|
||||
elif messages_to_call == 1:
|
||||
return cls(parser, messages_tosend[0])
|
||||
|
||||
if datahandler._trigger_remove_data(internal_data):
|
||||
if internal_data.data_handler_context.get(cls):
|
||||
del internal_data.data_handler_context[cls]
|
||||
|
||||
return datahandler
|
||||
return [cls(parser, ele) for ele in messages_tosend]
|
||||
|
||||
class HttpRequest(InternalBasicHttpMetaClass):
|
||||
"""
|
||||
@@ -279,22 +365,21 @@ class HttpRequest(InternalBasicHttpMetaClass):
|
||||
This data handler will be called twice, first with the headers complete, and second with the body complete
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# These will be used in the metaclass
|
||||
self._parser: InternalHttpRequest = InternalHttpRequest()
|
||||
self._headers_were_set = False
|
||||
@staticmethod
|
||||
def _associated_parser_class() -> Type[InternalHttpRequest]:
|
||||
return InternalHttpRequest
|
||||
|
||||
@staticmethod
|
||||
def _before_fetch_callable_checks(internal_data: DataStreamCtx):
|
||||
return internal_data.current_pkt.is_input
|
||||
|
||||
@property
|
||||
def method(self) -> bytes:
|
||||
"""Method of the request"""
|
||||
return self._parser.method_parsed
|
||||
|
||||
def _before_fetch_callable_checks(self, internal_data: DataStreamCtx):
|
||||
return internal_data.current_pkt.is_input
|
||||
return self._parser.msg.method
|
||||
|
||||
def __repr__(self):
|
||||
return f"<HttpRequest method={self.method} url={self.url} headers={self.headers} body={self.body} http_version={self.http_version} keep_alive={self.keep_alive} should_upgrade={self.should_upgrade} headers_complete={self.headers_complete} message_complete={self.message_complete} has_begun={self.has_begun} content_length={self.content_length} stream={self.stream}>"
|
||||
return f"<HttpRequest method={self.method} url={self.url} headers={self.headers} body={self.body} http_version={self.http_version} keep_alive={self.keep_alive} should_upgrade={self.should_upgrade} headers_complete={self.headers_complete} message_complete={self.message_complete} content_length={self.content_length} stream={self.stream}>"
|
||||
|
||||
class HttpResponse(InternalBasicHttpMetaClass):
|
||||
"""
|
||||
@@ -302,40 +387,30 @@ class HttpResponse(InternalBasicHttpMetaClass):
|
||||
This data handler will be called twice, first with the headers complete, and second with the body complete
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._parser: InternalHttpResponse = InternalHttpResponse()
|
||||
self._headers_were_set = False
|
||||
@staticmethod
|
||||
def _associated_parser_class() -> Type[InternalHttpResponse]:
|
||||
return InternalHttpResponse
|
||||
|
||||
@staticmethod
|
||||
def _before_fetch_callable_checks(internal_data: DataStreamCtx):
|
||||
return not internal_data.current_pkt.is_input
|
||||
|
||||
@property
|
||||
def status_code(self) -> int:
|
||||
"""Status code of the response"""
|
||||
return self._parser.status
|
||||
|
||||
def _before_fetch_callable_checks(self, internal_data: DataStreamCtx):
|
||||
return not internal_data.current_pkt.is_input
|
||||
return self._parser.msg.status
|
||||
|
||||
def __repr__(self):
|
||||
return f"<HttpResponse status_code={self.status_code} url={self.url} headers={self.headers} body={self.body} http_version={self.http_version} keep_alive={self.keep_alive} should_upgrade={self.should_upgrade} headers_complete={self.headers_complete} message_complete={self.message_complete} has_begun={self.has_begun} content_length={self.content_length} stream={self.stream}>"
|
||||
return f"<HttpResponse status_code={self.status_code} url={self.url} headers={self.headers} body={self.body} http_version={self.http_version} keep_alive={self.keep_alive} should_upgrade={self.should_upgrade} headers_complete={self.headers_complete} message_complete={self.message_complete} content_length={self.content_length} stream={self.stream}>"
|
||||
|
||||
class HttpRequestHeader(HttpRequest):
|
||||
"""
|
||||
HTTP Request Header handler
|
||||
This data handler will be called only once, the headers are complete, the body will be empty and not buffered
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._parser._save_body = False
|
||||
|
||||
def _before_fetch_callable_checks(self, internal_data: DataStreamCtx):
|
||||
return internal_data.current_pkt.is_input and not self._headers_were_set
|
||||
|
||||
def _after_fetch_callable_checks(self, internal_data: DataStreamCtx):
|
||||
if self._parser.headers_complete and not self._headers_were_set:
|
||||
self._headers_were_set = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def _contructor_hook(self):
|
||||
self._parser.save_body = False
|
||||
|
||||
class HttpResponseHeader(HttpResponse):
|
||||
"""
|
||||
@@ -343,15 +418,5 @@ class HttpResponseHeader(HttpResponse):
|
||||
This data handler will be called only once, the headers are complete, the body will be empty and not buffered
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._parser._save_body = False
|
||||
|
||||
def _before_fetch_callable_checks(self, internal_data: DataStreamCtx):
|
||||
return not internal_data.current_pkt.is_input and not self._headers_were_set
|
||||
|
||||
def _after_fetch_callable_checks(self, internal_data: DataStreamCtx):
|
||||
if self._parser.headers_complete and not self._headers_were_set:
|
||||
self._headers_were_set = True
|
||||
return True
|
||||
return False
|
||||
def _contructor_hook(self):
|
||||
self._parser.save_body = False
|
||||
@@ -71,7 +71,7 @@ class TCPInputStream(InternalTCPStream):
|
||||
|
||||
TCPClientStream = TCPInputStream
|
||||
|
||||
class TCPOutputStream:
|
||||
class TCPOutputStream(InternalTCPStream):
|
||||
"""
|
||||
This datamodel will assemble the TCP output stream from the server sent data.
|
||||
The function that use this data model will be handled when:
|
||||
|
||||
Reference in New Issue
Block a user