2023-09-22 20:46:50 +02:00
import sqlite3
from fastapi import APIRouter , HTTPException
from pydantic import BaseModel
from utils . sqlite import SQLite
2023-09-26 01:17:09 +02:00
from utils import ip_parse , ip_family , socketio_emit , PortType
2023-09-22 20:46:50 +02:00
from utils . models import ResetRequest , StatusMessageModel
from modules . firewall . nftables import FiregexTables
from modules . firewall . firewall import FirewallManager
class RuleModel(BaseModel):
    """A single firewall rule as exchanged with the frontend and stored in the ``rules`` table."""
    active: bool  # stored in the 'active' column (CHECK active IN (0, 1))
    name: str  # free-text label, VARCHAR(100) in the DB
    proto: str  # "tcp", "udp" or "any" (checked by parse_and_check_rule and the DB CHECK)
    ip_src: str  # "any" or an address normalised by utils.ip_parse
    ip_dst: str  # "any" or an address normalised by utils.ip_parse
    # Port ranges: the DB CHECKs require 1..65535 and from <= to; parse_and_check_rule
    # reorders each from/to pair. Exact PortType constraints live in utils (not visible here).
    port_src_from: PortType
    port_dst_from: PortType
    port_src_to: PortType
    port_dst_to: PortType
    action: str  # "accept", "drop" or "reject" (checked by parse_and_check_rule and the DB CHECK)
    mode: str  # direction; the DB CHECK only allows "O" (out) or "I" (in)
2023-09-23 00:23:01 +02:00
2023-09-24 05:48:54 +02:00
class RuleFormAdd(BaseModel):
    """Request body of ``POST /rules/set``: the complete new rule set plus the default policy."""
    rules: list[RuleModel]  # replaces every stored rule; list position becomes the rule_id
    policy: str  # default policy: "accept", "drop" or "reject" (checked by the handler)
2023-09-24 05:48:54 +02:00
class RuleInfo(BaseModel):
    """Response body of ``GET /rules``: stored rules, default policy and enabled flag."""
    rules: list[RuleModel]  # rules ordered by rule_id
    policy: str  # default policy as reported by the firewall manager
    enabled: bool  # whether the firewall is currently enabled
2023-09-23 00:23:01 +02:00
2023-09-22 20:46:50 +02:00
class RuleAddResponse(BaseModel):
    """Response of ``POST /rules/set``."""
    # 'ok' on success; the list form is declared for per-rule error descriptions.
    status: str | list[dict]
class RenameForm(BaseModel):
    """Request body carrying a new name (not referenced in the visible part of this module)."""
    name: str  # the new name
2023-09-26 01:17:09 +02:00
class FirewallSettings(BaseModel):
    """Global firewall options read by ``GET /settings`` and written by ``POST /settings/set``."""
    keep_rules: bool  # when true, shutdown() does not call firewall.close()
    # The two flags below are copied to FirewallManager attributes of the same
    # name; their effect on the generated rules is implemented there.
    allow_loopback: bool
    allow_established: bool
2023-09-22 20:46:50 +02:00
app = APIRouter()

# Database backing the firewall rule set. The 'rules' mapping lists each column
# name with its SQL type/constraints (presumably turned into a CREATE TABLE by
# utils.sqlite.SQLite — confirm there); 'QUERY' holds extra statements, here the
# unique index that rejects duplicated rules.
#
# String literals inside the CHECK constraints use single quotes: in SQL, double
# quotes delimit identifiers, and SQLite only accepts double-quoted string
# literals as a legacy fallback that builds compiled with SQLITE_DQS=0 reject.
db = SQLite('db/firewall-rules.db', {
    'rules': {
        'rule_id': 'INT PRIMARY KEY CHECK (rule_id >= 0)',
        'mode': "VARCHAR(1) NOT NULL CHECK (mode IN ('O', 'I'))",  # O = out, I = in
        'name': 'VARCHAR(100) NOT NULL',
        'active': 'BOOLEAN NOT NULL CHECK (active IN (0, 1))',
        'proto': "VARCHAR(3) NOT NULL CHECK (proto IN ('tcp', 'udp', 'any'))",
        'ip_src': 'VARCHAR(100) NOT NULL',
        'port_src_from': 'INT CHECK(port_src_from > 0 and port_src_from < 65536)',
        'port_src_to': 'INT CHECK(port_src_to > 0 and port_src_to < 65536 and port_src_from <= port_src_to)',
        'ip_dst': 'VARCHAR(100) NOT NULL',
        'port_dst_from': 'INT CHECK(port_dst_from > 0 and port_dst_from < 65536)',
        'port_dst_to': 'INT CHECK(port_dst_to > 0 and port_dst_to < 65536 and port_dst_from <= port_dst_to)',
        'action': "VARCHAR(10) NOT NULL CHECK (action IN ('accept', 'drop', 'reject'))",
    },
    'QUERY': [
        "CREATE UNIQUE INDEX IF NOT EXISTS unique_rules ON rules (proto, ip_src, ip_dst, port_src_from, port_src_to, port_dst_from, port_dst_to, mode);"
    ]
})

# Firewall manager operating on the same database handle.
firewall = FirewallManager(db)
2023-09-22 20:46:50 +02:00
async def reset(params: ResetRequest):
    """Reset the firewall module.

    The firewall manager is closed and the nftables tables are reset. Then:

    * if ``params.delete`` is true, the rules database is deleted and
      re-initialised empty;
    * otherwise the database is backed up first and restored afterwards.

    Finally the firewall manager is initialised again.
    """
    keep_data = not params.delete
    if keep_data:
        db.backup()
    await firewall.close()
    FiregexTables().reset()
    if keep_data:
        db.restore()
    else:
        db.delete()
        db.init()
    await firewall.init()
async def startup():
    """Initialise the rules database, then the firewall manager built on it."""
    db.init()
    await firewall.init()
async def shutdown():
    """Tear down the firewall module when the application stops.

    The database is backed up before anything else. ``firewall.close()`` is
    called only when ``keep_rules`` is false. Afterwards the DB connection is
    closed and the backup is restored.
    """
    # Read before any teardown; presumably keep_rules is persisted in the DB,
    # so it is fetched while the connection is still open — confirm in FirewallManager.
    keep_rules = firewall.keep_rules
    db.backup()
    if not keep_rules:
        # NOTE(review): presumably close() removes the installed rules — confirm in FirewallManager.
        await firewall.close()
    db.disconnect()
    # NOTE(review): restoring after disconnect presumably reverts DB changes made
    # during close(); confirm against utils.sqlite.SQLite.
    db.restore()
2023-09-26 01:17:09 +02:00
async def refresh_frontend(additional: list[str] | None = None):
    """Notify connected frontends that firewall state changed.

    Args:
        additional: extra names emitted together with ``"firewall"``
            (their meaning is defined by ``utils.socketio_emit``). ``None``
            or an empty list emits only ``"firewall"``.
    """
    # None default instead of a shared mutable [] default argument.
    await socketio_emit(["firewall", *(additional or [])])
2023-09-22 20:46:50 +02:00
async def apply_changes():
    """Reload the firewall from the current state and notify the frontend.

    Returns:
        The ``{'status': 'ok'}`` payload returned by the mutating endpoints.
    """
    await firewall.reload()
    await refresh_frontend()
    return dict(status='ok')
2023-09-26 01:17:09 +02:00
@app.get("/settings", response_model=FirewallSettings)
async def get_settings():
    """Get the firewall settings"""
    setting_names = ("keep_rules", "allow_loopback", "allow_established")
    return {setting: getattr(firewall, setting) for setting in setting_names}
@app.post("/settings/set", response_model=StatusMessageModel)
async def set_settings(form: FirewallSettings):
    """Set the firewall settings and apply them.

    The three flags are copied onto the firewall manager. The firewall is then
    reloaded and the frontend notified, as ``/enable`` and ``/disable`` do.
    The response payload is still ``{'status': 'ok'}``.
    """
    firewall.keep_rules = form.keep_rules
    firewall.allow_loopback = form.allow_loopback
    firewall.allow_established = form.allow_established
    # Without a reload the new loopback/established flags would presumably not
    # reach the installed rules until some later reload, and other clients
    # would not see the change.
    return await apply_changes()
2023-09-24 05:48:54 +02:00
@app.get('/rules', response_model=RuleInfo)
async def get_rule_list():
    """Return the stored rules (ordered by rule_id), the default policy and the enabled flag."""
    current_policy = firewall.policy
    stored_rules = db.query("SELECT active, name, proto, ip_src, ip_dst, port_src_from, port_dst_from, port_src_to, port_dst_to, action, mode FROM rules ORDER BY rule_id;")
    is_enabled = firewall.enabled
    return {
        "policy": current_policy,
        "rules": stored_rules,
        "enabled": is_enabled,
    }
2023-09-24 05:48:54 +02:00
@app.get('/enable', response_model=StatusMessageModel)
async def enable_firewall():
    """Mark the firewall as enabled, then reload it and notify the frontend."""
    firewall.enabled = True
    outcome = await apply_changes()
    return outcome
@app.get('/disable', response_model=StatusMessageModel)
async def disable_firewall():
    """Mark the firewall as disabled, then reload it and notify the frontend."""
    firewall.enabled = False
    outcome = await apply_changes()
    return outcome
2023-09-22 20:46:50 +02:00
def parse_and_check_rule(rule: RuleModel):
    """Validate and normalise one firewall rule in place, and return it.

    * ``ip_src`` / ``ip_dst``: the word ``"any"`` (case-insensitive, surrounding
      whitespace ignored) becomes ``"any"``. Anything else is normalised with
      ``ip_parse``. When both ends are concrete addresses they must belong
      to the same family.
    * Each port range is reordered so that ``*_from <= *_to``.
    * ``proto``, ``action`` and ``mode`` must be among the values allowed by
      the ``rules`` table CHECK constraints.

    Raises:
        HTTPException: status 400 describing the first invalid field.
    """
    src_is_any = rule.ip_src.strip().lower() == "any"
    dst_is_any = rule.ip_dst.strip().lower() == "any"
    # Each end is normalised on its own: "any" on one side must not discard a
    # concrete address given for the other side, which would silently widen
    # the rule.
    # NOTE(review): relies on FirewallManager handling "any" per side — confirm.
    try:
        rule.ip_src = "any" if src_is_any else ip_parse(rule.ip_src)
        rule.ip_dst = "any" if dst_is_any else ip_parse(rule.ip_dst)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid address")
    if not (src_is_any or dst_is_any) and ip_family(rule.ip_dst) != ip_family(rule.ip_src):
        raise HTTPException(status_code=400, detail="Destination and source addresses must be of the same family")

    # Accept ranges given in either order; the DB requires from <= to.
    rule.port_dst_from, rule.port_dst_to = min(rule.port_dst_from, rule.port_dst_to), max(rule.port_dst_from, rule.port_dst_to)
    rule.port_src_from, rule.port_src_to = min(rule.port_src_from, rule.port_src_to), max(rule.port_src_from, rule.port_src_to)

    if rule.proto not in ["tcp", "udp", "any"]:
        raise HTTPException(status_code=400, detail="Invalid protocol")
    if rule.action not in ["accept", "drop", "reject"]:
        raise HTTPException(status_code=400, detail="Invalid action")
    # Mirrors the CHECK on the 'mode' column. Without this check a bad mode only
    # fails at insert time as an IntegrityError, reported as "duplicated rules".
    if rule.mode not in ["O", "I"]:
        raise HTTPException(status_code=400, detail="Invalid mode")

    return rule
@app.post('/rules/set', response_model=RuleAddResponse)
async def add_new_service(form: RuleFormAdd):
    """Replace the whole rule set and the default policy, then apply them.

    Each rule is validated and normalised by ``parse_and_check_rule``. The
    ``rules`` table is then cleared and the rules are inserted with ``rule_id``
    equal to their position in the request, which is the order ``GET /rules``
    returns them in.

    Raises:
        HTTPException: 400 for an invalid policy or rule, or when the database
            rejects the new rule set (e.g. duplicates hitting the unique index).
    """
    if form.policy not in ["accept", "drop", "reject"]:
        raise HTTPException(status_code=400, detail="Invalid policy")
    # parse_and_check_rule raises HTTPException on the first invalid rule and
    # otherwise returns the rule itself, so no per-rule error list can arise here.
    rules = [parse_and_check_rule(ele) for ele in form.rules]
    insert_sql = """
        INSERT INTO rules (
            rule_id, active, name,
            proto,
            ip_src, ip_dst,
            port_src_from, port_dst_from,
            port_src_to, port_dst_to,
            action, mode
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"""
    try:
        db.queries(["DELETE FROM rules"] + [
            (
                insert_sql,
                rid, ele.active, ele.name,
                ele.proto,
                ele.ip_src, ele.ip_dst,
                ele.port_src_from, ele.port_dst_from,
                ele.port_src_to, ele.port_dst_to,
                ele.action, ele.mode,
            )
            for rid, ele in enumerate(rules)
        ])
    except sqlite3.IntegrityError:
        raise HTTPException(status_code=400, detail="Error saving the rules: maybe there are duplicated rules")
    firewall.policy = form.policy
    return await apply_changes()