WIP Add Async #81

Closed · wants to merge 4 commits
179 changes: 178 additions & 1 deletion offchain/web3/contract_caller.py
@@ -1,11 +1,13 @@
import asyncio
from logging import Logger
from typing import Optional, Any

from eth_abi import encode_abi, decode_abi
from eth_utils import to_hex

from offchain.concurrency import parmap
from offchain.web3.contract_utils import function_signature_to_sighash
from offchain.web3.jsonrpc import EthereumJSONRPC
from offchain.web3.jsonrpc import EthereumJSONRPC, AsyncEthereumJSONRPC

CHUNK_SIZE = 500  # default number of calls grouped into a single batched JSON-RPC request

@@ -183,3 +185,178 @@ def decode_response(self, response: dict, return_types: list[str]) -> Optional[Any]

        except Exception:
            return None


class AsyncContractCaller:
    def __init__(self, rpc: Optional[AsyncEthereumJSONRPC] = None) -> None:
        self.rpc = rpc or AsyncEthereumJSONRPC()

    async def single_address_single_fn_many_args(
        self,
        address: str,
        function_sig: str,
        return_type: list[str],
        args: list[list[Any]],
        block_tag: Optional[str] = "latest",
        chunk_size: int = CHUNK_SIZE,
        **kwargs,
    ) -> list[Optional[Any]]:
        """Call a single function on a single address with many different permutations of arguments

        Args:
            address (str): address to call function on
            function_sig (str): function signature (ex: "totalSupply()")
            return_type (list[str]): ABI types of the return value (ex: ["uint256"])
            args (list[list[Any]]): list of arguments passed in each fn call (ex: [[1], [2], [3]])
            block_tag (Optional[str], optional): block tag to execute the call at. Defaults to "latest".
            chunk_size (int, optional): number of calls to group in a single req. Defaults to 500.

        Returns:
            list[Optional[Any]]: list of returned values, mapped 1-1 with args
        """

        req_params = [
            self.request_builder(address, function_sig, args[i], block_tag, **kwargs) for i in range(len(args))
        ]
        res = await self._call_batch_chunked(req_params, chunk_size)
        return list(map(lambda r: self.decode_response(r, return_type), res))

    async def single_address_many_fns_many_args(
        self,
        address: str,
        function_sigs: list[str],
        return_types: list[list[str]],
        args: list[list[Any]],
        block_tag: Optional[str] = "latest",
        chunk_size: int = CHUNK_SIZE,
    ) -> dict[str, Optional[Any]]:
        """Call many functions on a single address with different arguments per function

        Args:
            address (str): address to call function on
            function_sigs (list[str]): list of fn signatures (ex: ["totalSupply()", "symbol()"])
            return_types (list[list[str]]): ABI return types per fn (ex: [["uint256"], ["string"]])
            args (list[list[Any]]): list of arguments passed in each fn call (ex: [[1], [2], [3]])
            chunk_size (int, optional): number of calls to group in a single req. Defaults to 500.

        Returns:
            dict[str, Optional[Any]]: dict with fn signatures as keys (ex: {"totalSupply()": 1234})
        """
        assert len(function_sigs) == len(args) and len(args) == len(
            return_types
        ), "function names, return types, args must all be the same length"
        req_params = [self.request_builder(address, function_sigs[i], args[i], block_tag) for i in range(len(args))]
        res = await self._call_batch_chunked(req_params, chunk_size)
        cleaned = list(map(lambda i: self.decode_response(res[i], return_types[i]), range(len(res))))
        return {k: v for k, v in zip(function_sigs, cleaned)}

    async def _call_batch_chunked(self, request_params: list[list[Any]], chunk_size: int = CHUNK_SIZE) -> list[Any]:
        """Perform concurrent batched requests by splitting a large batch into smaller chunks

        For example, 1,200 request params with a chunk size of 500 are split into
        chunks of 500, 500, and 200, fetched concurrently, and merged in order.

        Args:
            request_params (list[list[Any]]): list of request parameters
            chunk_size (int, optional): size at which to split requests. Defaults to 500.

        Returns:
            list[Any]: merged list of all data from the many requests
        """

        async def call(params: list[list[Any]]) -> list[Any]:
            return await self.rpc.call_batch_chunked("eth_call", params)

        size = len(request_params)
        if size < chunk_size:
            return await call(request_params)

        prev_offset, curr_offset = 0, chunk_size

        chunks = []
        while prev_offset < size:
            chunks.append(request_params[prev_offset:curr_offset])
            prev_offset = curr_offset
            curr_offset = min(curr_offset + chunk_size, size)

        results = await asyncio.gather(*[call(chunk) for chunk in chunks])
        return [i for res in results for i in res]

    def request_builder(
        self,
        address: str,
        function_sig: str,
        args: Optional[list] = None,
        block_tag: Optional[str] = "latest",
        **kwargs,
    ):
        """Request generation function. Can be overloaded via inheritance for custom RPC requests.

        Args:
            address (str): address to call
            function_sig (str): function signature
            args (Optional[list], optional): arguments for function if present
            block_tag (Optional[str], optional): block tag to execute the call at. Defaults to "latest".

        Returns:
            list: eth_call request params, i.e. [{"to": address, "data": data}, block_tag]
        """
        data = self.encode_params(function_sig, args, **kwargs)
        return [{"to": address, "data": data}, block_tag]

    def encode_params(
        self,
        function_sig: str,
        args: Optional[list] = None,
        arg_types: Optional[list] = None,
        **kwargs,
    ) -> str:
        """Encode eth_call data by first taking the function sighash, then adding the encoded data

        For example, "balanceOf(address)" yields the sighash "0x70a08231", to which
        the ABI-encoded address argument is appended.

        Args:
            function_sig (str): function signature
            args (Optional[list], optional): arguments to pass. Defaults to None.
            arg_types (Optional[list], optional): ABI types of args; parsed from function_sig if omitted.

        Returns:
            str: hex-encoded call data (4-byte sighash followed by the ABI-encoded args)
        """
        b = bytes.fromhex(function_signature_to_sighash(function_sig)[2:])

        if args is not None:
            if arg_types is None:
                start = function_sig.find("(")
                arg_types = function_sig[start:].strip("()").split(",")

            if isinstance(arg_types, str):
                arg_types = [arg_types]

            b += encode_abi(arg_types, args)

        return to_hex(b)

    def decode_response(self, response: dict, return_types: list[str]) -> Optional[Any]:
        """Decode responses, filling None for any errored requests

        Args:
            response (dict): JSON-RPC response object for a single call
            return_types (list[str]): ABI types to decode the result with

        Returns:
            Optional[Any]: decoded value(s), or None if the call errored or returned no data
        """
        try:
            data = response.get("result")
            if data is None:
                return None

            trimmed = data[2:]
            if trimmed == "":
                return None

            parsed = decode_abi(return_types, bytes.fromhex(trimmed))
            n_expected, n_received = len(return_types), len(parsed)

            if n_expected == 1 or n_received == 1:
                return parsed[0]
            elif n_expected < n_received:
                return parsed[:n_expected]
            else:
                return parsed

        except Exception:
            return None
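
A minimal usage sketch for the async contract caller (not part of the diff; it assumes a reachable RPC provider, and the contract address is a placeholder):

    import asyncio

    from offchain.web3.contract_caller import AsyncContractCaller

    async def main():
        caller = AsyncContractCaller()
        # ownerOf(tokenId) for several token ids, batched into one chunked request
        owners = await caller.single_address_single_fn_many_args(
            address="0x0000000000000000000000000000000000000000",  # placeholder
            function_sig="ownerOf(uint256)",
            return_type=["address"],
            args=[[1], [2], [3]],
        )
        print(owners)  # one decoded value per arg set; None for failed calls
        await caller.rpc.sess.close()

    asyncio.run(main())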
73 changes: 73 additions & 0 deletions offchain/web3/jsonrpc.py
@@ -2,6 +2,8 @@

import requests
import requests.adapters
import asyncio
import aiohttp

from offchain.concurrency import parmap
from offchain.constants.providers import RPCProvider
@@ -94,3 +96,74 @@ def call_batch_chunked(

        results = parmap(lambda chunk: self.call_batch(method, chunk), chunks)
        return [i for res in results for i in res]


class AsyncEthereumJSONRPC:
    def __init__(
        self,
        provider_url: Optional[str] = None,
    ) -> None:
        # Note: aiohttp recommends creating a ClientSession inside a running
        # event loop, so this class should be instantiated from async code.
        self.sess = aiohttp.ClientSession()
        self.url = provider_url or RPCProvider.LLAMA_NODES_MAINNET

    def __payload_factory(self, method: str, params: list[Any], id: int) -> RPCPayload:
        return {"method": method, "params": params, "id": id, "jsonrpc": "2.0"}

    @retry(
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=1, min=1, max=5),
    )
    async def call(self, method: str, params: list[dict]) -> dict:
        try:
            payload = self.__payload_factory(method, params, 1)
            async with self.sess.post(self.url, json=payload) as resp:
                resp.raise_for_status()
                return await resp.json()
        except Exception as e:
            logger.error(
                f"Caught exception while making rpc call. Method: {method}. Params: {params}. Retrying. Error: {e}"
            )
            raise

    @retry(
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=1, min=1, max=5),
    )
    async def call_batch(self, method: str, params: list[list[Any]]) -> list[dict]:
        try:
            # One payload per param set, with incrementing ids per JSON-RPC batch semantics
            payload = [self.__payload_factory(method, param, i) for i, param in enumerate(params)]
            async with self.sess.post(self.url, json=payload) as resp:
                resp.raise_for_status()
                return await resp.json()
        except Exception as e:
            logger.error(
                f"Caught exception while making batch rpc call. "
                f"Method: {method}. Params: {params}. Retrying. Error: {e}"
            )
            raise

    @retry(
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=1, min=1, max=5),
    )
    async def call_batch_chunked(
        self,
        method: str,
        params: list[list[Any]],
        chunk_size: Optional[int] = MAX_REQUEST_BATCH_SIZE,
    ) -> list[dict]:
        size = len(params)
        if size < chunk_size:
            return await self.call_batch(method, params)

        prev_offset, curr_offset = 0, chunk_size

        chunks = []
        while prev_offset < size:
            chunks.append(params[prev_offset:curr_offset])
            prev_offset = curr_offset
            curr_offset = min(curr_offset + chunk_size, size)

        # Fire all chunked batches concurrently, then flatten the results in order
        results = await asyncio.gather(*[self.call_batch(method, chunk) for chunk in chunks])
        return [i for res in results for i in res]
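
A minimal usage sketch for the async RPC client (not part of the diff; it assumes a running event loop and that the default provider URL is reachable):

    import asyncio

    from offchain.web3.jsonrpc import AsyncEthereumJSONRPC

    async def main():
        rpc = AsyncEthereumJSONRPC()
        try:
            # Single JSON-RPC call: current block number as a hex string
            block = await rpc.call("eth_blockNumber", [])
            print(block["result"])
        finally:
            # aiohttp sessions should be closed explicitly
            await rpc.sess.close()

    asyncio.run(main())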