contract;
use utils::utils::is_stable;
use interfaces::{data_structures::{Asset, PoolId, PoolMetadata}, mira_amm::MiraAMM};
use math::pool_math::{calculate_fee, validate_curve};
// Deploy-time configuration for this hook contract.
configurable {
// Contract id of the Mira AMM that `hook` queries for pool metadata and fees.
// Defaults to the zero id; NOTE(review): presumably must be overridden at
// deployment, otherwise the AMM calls in `hook` target the zero contract.
AMM_CONTRACT_ID: ContractId = ContractId::zero(),
}
/// Hook interface receiving the amounts involved in a pool operation.
///
/// NOTE(review): presumably invoked by the AMM around each swap/mint/burn;
/// the caller is not visible in this file.
abi IBaseHook {
/// Inspects one pool operation.
///
/// * `pool_id` - pool the operation was performed on.
/// * `sender`, `to` - identities involved in the operation.
/// * `asset_0_in`, `asset_1_in` - amounts of each asset sent into the pool.
/// * `asset_0_out`, `asset_1_out` - amounts of each asset taken out of the pool.
/// * `lp_token` - LP token amount involved; the implementation in this file
///   treats `0` as a swap.
#[storage(read, write)]
fn hook(
pool_id: PoolId,
sender: Identity,
to: Identity,
asset_0_in: u64,
asset_1_in: u64,
asset_0_out: u64,
asset_1_out: u64,
lp_token: u64,
);
}
/// Re-checks the pool curve invariant for a swap that has already been applied.
///
/// `current_reserve_*` are the reserves after the swap. From them two
/// snapshots are derived and handed to `validate_curve`:
/// * the reserves before the swap: the reserve grew by the input minus its
///   protocol fee, and shrank by the output, so that is undone;
/// * the current reserves with the LP-fee share of each input excluded.
///
/// Fees are obtained from `calculate_fee(amount, fee)` on each input amount.
/// Reverts if `validate_curve` rejects the snapshots or if any of the u64
/// arithmetic underflows/overflows.
fn post_validate_curve(
    is_stable: bool,
    current_reserve_0: u64,
    current_reserve_1: u64,
    decimals_0: u8,
    decimals_1: u8,
    asset_0_in: u64,
    asset_1_in: u64,
    asset_0_out: u64,
    asset_1_out: u64,
    lp_fee: u64,
    protocol_fee: u64,
) {
    // Fee cuts taken from each input amount.
    let (protocol_cut_0, protocol_cut_1) = (
        calculate_fee(asset_0_in, protocol_fee),
        calculate_fee(asset_1_in, protocol_fee),
    );
    let (lp_cut_0, lp_cut_1) = (
        calculate_fee(asset_0_in, lp_fee),
        calculate_fee(asset_1_in, lp_fee),
    );

    // Only the input net of the protocol fee was added to each reserve.
    let (growth_0, growth_1) = (asset_0_in - protocol_cut_0, asset_1_in - protocol_cut_1);

    // Undo the swap to recover the reserves it started from.
    let (before_0, before_1) = (
        current_reserve_0 - growth_0 + asset_0_out,
        current_reserve_1 - growth_1 + asset_1_out,
    );

    // Post-swap reserves with the LP fee excluded from the invariant check.
    let (after_0, after_1) = (current_reserve_0 - lp_cut_0, current_reserve_1 - lp_cut_1);

    validate_curve(
        is_stable,
        after_0,
        after_1,
        before_0,
        before_1,
        decimals_0,
        decimals_1,
    );
}
/// Reverts with "Decimals too big" unless both assets of `pool` use at most
/// 9 decimals.
fn validate_stable_pool_meta(pool: PoolMetadata) {
    let max_decimals: u8 = 9;
    let within_limit = pool.decimals_0 <= max_decimals && pool.decimals_1 <= max_decimals;
    require(within_limit, "Decimals too big");
}
impl IBaseHook for Contract {
    /// Validates a pool operation, classified from its amounts:
    ///
    /// * `lp_token == 0` - a swap. The pool's current reserves and decimals
    ///   are read from the AMM at `AMM_CONTRACT_ID`. The stable or volatile
    ///   fee pair from `amm.fees()` is selected according to
    ///   `is_stable(pool_id)`. The curve invariant is then re-checked with
    ///   `post_validate_curve`.
    /// * `lp_token != 0` and both outputs zero - a mint. For stable pools the
    ///   pool decimals are checked with `validate_stable_pool_meta`.
    /// * anything else (presumably a burn) - no checks.
    ///
    /// Reverts if the AMM returns no metadata for `pool_id` (`unwrap`) or if
    /// any validation fails. `sender` and `to` are part of the hook interface
    /// but are not used here.
    #[storage(read, write)]
    fn hook(
        pool_id: PoolId,
        sender: Identity,
        to: Identity,
        asset_0_in: u64,
        asset_1_in: u64,
        asset_0_out: u64,
        asset_1_out: u64,
        lp_token: u64,
    ) {
        if (lp_token == 0) {
            // it's a swap
            let amm = abi(MiraAMM, AMM_CONTRACT_ID.into());
            let pool = amm.pool_metadata(pool_id).unwrap();
            let (lp_fee_volatile, lp_fee_stable, protocol_fee_volatile, protocol_fee_stable) = amm.fees();
            // Classify the pool once so the fee tier and the curve check are
            // guaranteed to use the same stable/volatile decision.
            let stable = is_stable(pool_id);
            let (lp_fee, protocol_fee) = if stable {
                (lp_fee_stable, protocol_fee_stable)
            } else {
                (lp_fee_volatile, protocol_fee_volatile)
            };
            post_validate_curve(
                stable,
                pool.reserve_0,
                pool.reserve_1,
                pool.decimals_0,
                pool.decimals_1,
                asset_0_in,
                asset_1_in,
                asset_0_out,
                asset_1_out,
                lp_fee,
                protocol_fee,
            );
        } else if (asset_0_out == 0 && asset_1_out == 0) {
            // it's a mint
            if (is_stable(pool_id)) {
                let amm = abi(MiraAMM, AMM_CONTRACT_ID.into());
                let pool = amm.pool_metadata(pool_id).unwrap();
                validate_stable_pool_meta(pool);
            }
        }
    }
}
/// Test helper that feeds every case to `post_validate_curve`.
///
/// Each case is `(is_stable, res_0, res_1, input_0, input_1, output_0,
/// output_1, dec_0, dec_1)`. `fees` is `(lp_fee, protocol_fee)`. When it is
/// `None`, an LP fee of 30 and a protocol fee of 0 are used. The first
/// failing case reverts.
fn run_test_cases(
    cases: Vec<(bool, u64, u64, u64, u64, u64, u64, u8, u8)>,
    fees: Option<(u64, u64)>,
) {
    let (lp_fee, protocol_fee) = match fees {
        Some(pair) => pair,
        None => (30, 0),
    };
    let total = cases.len();
    let mut idx = 0;
    while idx < total {
        let (stable, reserve_0, reserve_1, in_0, in_1, out_0, out_1, decimals_0, decimals_1) = cases.get(idx).unwrap();
        post_validate_curve(
            stable,
            reserve_0,
            reserve_1,
            decimals_0,
            decimals_1,
            in_0,
            in_1,
            out_0,
            out_1,
            lp_fee,
            protocol_fee,
        );
        idx += 1;
    }
}
// Volatile-pool swaps that satisfy the invariant. Each case in the "different
// decimals" section mirrors the same-decimals case above it with other decimals.
#[test]
fn test_post_validate_curve_volatile() {
    // is_stable, res_0, res_1, input_0, input_1, output_0, output_1, dec_0, dec_1
    let mut test_cases: Vec<(bool, u64, u64, u64, u64, u64, u64, u8, u8)> = Vec::new();
    // volatile pool, same decimals
    // 990 * 1001 (990_990) < (1000 - 1) * 1000 (999_000)
    test_cases.push((false, 1000, 1000, 10, 0, 0, 1, 6, 6));
    // 990 * 1009 (998_910) < (1000 - 1) * 1000 (999_000)
    test_cases.push((false, 1000, 1000, 10, 0, 0, 9, 6, 6));
    // 998_996 * 1002 (1_000_993_992) < (999_999 - 3) * 1001 (1_000_995_996)
    test_cases.push((false, 999_999, 1001, 1003, 0, 0, 1, 2, 2));
    // 998_996_989 * 1_002_000 (1_000_994_982_978_000) < (999_999_999 - 3009) * 1_001_000 (1_000_996_986_990_000)
    test_cases.push((false, 999_999_999, 1_001_000, 1_003_010, 0, 0, 1000, 2, 2));
    // volatile pool, different decimals
    // 990 * 1001 (990_990) < (1000 - 1) * 1000 (999_000)
    // Asset 0 in, asset 1 out, matching the comment above. The previous tuple
    // had the output on asset 0 too, which never exercised a cross-asset swap.
    test_cases.push((false, 1000, 1000, 10, 0, 0, 1, 2, 8));
    // 990 * 1009 (998_910) < (1000 - 1) * 1000 (999_000)
    test_cases.push((false, 1000, 1000, 10, 0, 0, 9, 10, 0));
    // 998_996 * 1002 (1_000_993_992) < (999_999 - 3) * 1001 (1_000_995_996)
    test_cases.push((false, 999_999, 1001, 1003, 0, 0, 1, 2, 3));
    // 998_996_989 * 1_002_000 (1_000_994_982_978_000) < (999_999_999 - 3009) * 1_001_000 (1_000_996_986_990_000)
    test_cases.push((false, 999_999_999, 1_001_000, 1_003_010, 0, 0, 1000, 5, 4));
    run_test_cases(test_cases, None);
}
// The first case satisfies the volatile invariant and the second breaks it,
// so `run_test_cases` must revert on the second case.
#[test(should_revert)]
fn test_post_validate_curve_volatile_failure() {
    // is_stable, res_0, res_1, input_0, input_1, output_0, output_1, dec_0, dec_1
    let mut test_cases: Vec<(bool, u64, u64, u64, u64, u64, u64, u8, u8)> = Vec::new();
    // 990 * 1009 (998_910) < (1000 - 1) * 1000 (999_000) OK
    test_cases.push((false, 1000, 1000, 10, 0, 0, 9, 6, 6));
    // 990 * 1010 (999_900) > (1000 - 1) * 1000 (999_000) VIOLATION
    test_cases.push((false, 1000, 1000, 10, 0, 0, 10, 6, 6));
    run_test_cases(test_cases, None);
}
// A swap with no LP fee and a 1% protocol fee. Only the input net of the
// protocol fee counts towards the reserve reconstruction.
#[test]
fn test_protocol_fee_calculation() {
    // is_stable, res_0, res_1, input_0, input_1, output_0, output_1, dec_0, dec_1
    let mut cases: Vec<(bool, u64, u64, u64, u64, u64, u64, u8, u8)> = Vec::new();
    // current reserves: 10000, 10000
    // previous reserves: 9010, 11098, 10 - protocol fee
    cases.push((false, 10000, 10000, 1000, 0, 0, 1098, 6, 6));
    // (lp_fee, protocol_fee): no LP fee, 1% protocol fee
    let fees: Option<(u64, u64)> = Some((0, 100));
    run_test_cases(cases, fees);
}
/// Test helper that builds a `PoolMetadata` with the given decimals, zero
/// reserves and a zero-amount liquidity asset under the default asset id.
fn build_pool_meta(decimals_0: u8, decimals_1: u8) -> PoolMetadata {
    let liquidity = Asset::new(AssetId::default(), 0);
    PoolMetadata {
        reserve_0: 0,
        reserve_1: 0,
        liquidity,
        decimals_0,
        decimals_1,
    }
}
// Decimal pairs within the 9-decimal limit (including the boundaries) are accepted.
#[test]
fn test_stable_pool_validation() {
    let pairs: [(u8, u8); 7] = [(0, 0), (0, 9), (9, 0), (9, 9), (5, 7), (2, 8), (8, 2)];
    let mut idx = 0;
    while idx < 7 {
        let (d0, d1) = pairs[idx];
        validate_stable_pool_meta(build_pool_meta(d0, d1));
        idx += 1;
    }
}
// 10 decimals on asset 0 exceeds the limit and must revert.
#[test(should_revert)]
fn test_stable_pool_validation_failure_decimals_0() {
    let meta = build_pool_meta(10, 0);
    validate_stable_pool_meta(meta);
}
// 10 decimals on asset 1 exceeds the limit and must revert.
#[test(should_revert)]
fn test_stable_pool_validation_failure_decimals_1() {
    let meta = build_pool_meta(0, 10);
    validate_stable_pool_meta(meta);
}