From bbf7f664c7c40a76aa9c54f78518c5b3ea04c082 Mon Sep 17 00:00:00 2001 From: Steam Deck User <0xlomonoshka@gmail.com> Date: Fri, 8 Dec 2023 15:08:56 +0400 Subject: [PATCH] add useExchageBalance function, add balance and allowance validation --- package.json | 2 +- src/Unit/Exchange/generateSwapCalldata.ts | 110 ++++++++++++++++------ src/utils/getBalance.ts | 29 +++++- 3 files changed, 108 insertions(+), 33 deletions(-) diff --git a/package.json b/package.json index 20109f3..d31f694 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@orionprotocol/sdk", - "version": "0.20.28", + "version": "0.20.29-rc0", "description": "Orion Protocol SDK", "main": "./lib/index.cjs", "module": "./lib/index.js", diff --git a/src/Unit/Exchange/generateSwapCalldata.ts b/src/Unit/Exchange/generateSwapCalldata.ts index 24ed183..986a288 100644 --- a/src/Unit/Exchange/generateSwapCalldata.ts +++ b/src/Unit/Exchange/generateSwapCalldata.ts @@ -18,7 +18,7 @@ import { generateCurveStableSwapCall } from "./callGenerators/curve.js"; import type { SingleSwap } from "../../types.js"; import { addressLikeToString } from "../../utils/addressLikeToString.js"; import { generateUnwrapAndTransferCall, generateWrapAndTransferCall } from "./callGenerators/weth.js"; -import { getExchangeBalance, getWalletBalance } from "../../utils/getBalance.js"; +import { getExchangeAllowance, getExchangeBalance, getTotalBalance, getWalletBalance } from "../../utils/getBalance.js"; export type Factory = "UniswapV2" | "UniswapV3" | "Curve" | "OrionV2" | "OrionV3"; @@ -109,7 +109,7 @@ export async function generateSwapCalldata({ const swapExecutorContractAddress = await addressLikeToString(swapExecutorContractAddressLike); let path = SafeArray.from(arrayLikePath); - const { factory, assetIn: srcToken } = path.first(); + const { assetIn: srcToken } = path.first(); const { assetOut: dstToken } = path.last(); let swapDescription: LibValidator.SwapDescriptionStruct = { @@ -128,15 +128,51 @@ export async function generateSwapCalldata({ if (singleSwap.assetOut == ethers.ZeroAddress) singleSwap.assetOut = wethAddress; return singleSwap; }); + + let calls: BytesLike[]; + ({ swapDescription, calls } = await processSwaps( + swapDescription, + path, + amountNativeDecimals, + wethAddress, + swapExecutorContractAddress, + curveRegistryAddress, + provider + )); + const calldata = generateCalls(calls); - const isSingleFactorySwap = path.every((singleSwap) => singleSwap.factory === factory); + const { useExchangeBalance, additionalTransferAmount } = await shouldUseExchangeBalance( + srcToken, + initiatorAddress, + exchangeContractAddress, + amountNativeDecimals, + provider + ); + if (useExchangeBalance) { + swapDescription.flags = 1n << 255n; + } + + return { swapDescription, calldata, value: additionalTransferAmount }; +} + +async function processSwaps( + swapDescription: LibValidator.SwapDescriptionStruct, + path: SafeArray, + amount: BigNumberish, + wethAddress: string, + swapExecutorContractAddress: string, + curveRegistryAddress: string, + provider: JsonRpcProvider +) { + const { factory: firstSwapFactory } = path.first(); + const isSingleFactorySwap = path.every((singleSwap) => singleSwap.factory === firstSwapFactory); let calls: BytesLike[]; if (isSingleFactorySwap) { ({ swapDescription, calls } = await processSingleFactorySwaps( - factory, + firstSwapFactory, swapDescription, path, - amountNativeDecimals, + amount, swapExecutorContractAddress, curveRegistryAddress, provider @@ -145,41 +181,20 @@ export async function generateSwapCalldata({ ({ swapDescription, calls } = await processMultiFactorySwaps( swapDescription, path, - amountNativeDecimals, + amount, swapExecutorContractAddress, curveRegistryAddress, provider )); } - ({ swapDescription, calls } = await wrapOrUnwrapIfNeeded( - amountNativeDecimals, + amount, swapDescription, calls, swapExecutorContractAddress, wethAddress )); - const calldata = generateCalls(calls); - - const initiatorWalletBalance = await getWalletBalance(srcToken, initiatorAddress, provider); - const initiatorExchangeBalance = await getExchangeBalance( - srcToken, - initiatorAddress, - exchangeContractAddress, - provider, - true - ); - const useExchangeBalance = - initiatorExchangeBalance !== 0n && (srcToken === ZeroAddress || initiatorWalletBalance < amountNativeDecimals); - if (useExchangeBalance) { - swapDescription.flags = 1n << 255n; - } - let value = 0n; - if (srcToken === ZeroAddress && initiatorExchangeBalance < amountNativeDecimals) { - value = amountNativeDecimals - initiatorExchangeBalance; - } - - return { swapDescription, calldata, value }; + return { swapDescription, calls }; } async function processSingleFactorySwaps( @@ -313,3 +328,40 @@ async function wrapOrUnwrapIfNeeded( } return { swapDescription, calls }; } + +async function shouldUseExchangeBalance( + srcToken: AddressLike, + initiatorAddress: AddressLike, + exchangeContractAddress: AddressLike, + amount: bigint, + provider: ethers.provider +) { + const { walletBalance, exchangeBalance } = await getTotalBalance( + srcToken, + initiatorAddress, + exchangeContractAddress, + provider + ); + const exchangeAllowance = await getExchangeAllowance(srcToken, initiatorAddress, exchangeContractAddress, provider); + + if (walletBalance + exchangeBalance < amount) { + throw new Error( + `Not enough balance to make swap, totalBalance - ${walletBalance + exchangeBalance} swapAmount - ${amount}` + ); + } + let useExchangeBalance = true; + let additionalTransferAmount = 0n; + + if (exchangeBalance == 0n) { + useExchangeBalance = false; + additionalTransferAmount = amount; + } else { + additionalTransferAmount = exchangeBalance >= amount ? 0n : amount - exchangeBalance; + if (additionalTransferAmount > exchangeAllowance) { + throw new Error( + `Not enough allowance to make swap, allowance - ${exchangeAllowance} needed allowance - ${additionalTransferAmount}` + ); + } + } + return { useExchangeBalance, additionalTransferAmount }; +} diff --git a/src/utils/getBalance.ts b/src/utils/getBalance.ts index c139574..c863f65 100644 --- a/src/utils/getBalance.ts +++ b/src/utils/getBalance.ts @@ -58,7 +58,7 @@ async function getExchangeBalanceERC20( exchangeAddress = await addressLikeToString(exchangeAddress); tokenAddress = await addressLikeToString(tokenAddress); - const exchange = Exchange__factory.connect(exchangeAddress, provider) + const exchange = Exchange__factory.connect(exchangeAddress, provider); const exchangeBalance = await exchange.getBalance(tokenAddress, walletAddress); if (convertToNativeDecimals) { @@ -79,7 +79,7 @@ async function getExchangeBalanceNative( ) { walletAddress = await addressLikeToString(walletAddress); exchangeAddress = await addressLikeToString(exchangeAddress); - const exchange = Exchange__factory.connect(exchangeAddress, provider) + const exchange = Exchange__factory.connect(exchangeAddress, provider); const exchangeBalance = await exchange.getBalance(ZeroAddress, walletAddress); if (convertToNativeDecimals) { @@ -107,6 +107,25 @@ export async function getExchangeBalance( } } +export async function getExchangeAllowance( + tokenAddress: AddressLike, + walletAddress: AddressLike, + exchangeAddress: AddressLike, + provider: ethers.Provider +) { + if (typeof tokenAddress === "string" && tokenAddress === ZeroAddress) { + return 0n; + } else { + walletAddress = await addressLikeToString(walletAddress); + tokenAddress = await addressLikeToString(tokenAddress); + + const tokenContract = ERC20__factory.connect(tokenAddress, provider); + let allowance = await tokenContract.allowance(walletAddress, exchangeAddress); + + return allowance; + } +} + async function getWalletBalanceERC20( tokenAddress: AddressLike, walletAddress: AddressLike, @@ -172,5 +191,9 @@ export async function getTotalBalance( provider, convertToNativeDecimals ); - return walletBalance + exchangeBalance; + return { + walletBalance, + exchangeBalance, + totalBalance: walletBalance + exchangeBalance + } }