Refactor CLN stubs to allow for mocking

This commit is contained in:
Reckless_Satoshi 2023-11-08 14:56:36 +00:00 committed by Reckless_Satoshi
parent 68b1186277
commit bece7c7d4e

View File

@ -48,10 +48,6 @@ class CLNNode:
hold_channel = grpc.secure_channel(CLN_GRPC_HOLD_HOST, creds) hold_channel = grpc.secure_channel(CLN_GRPC_HOLD_HOST, creds)
node_channel = grpc.secure_channel(CLN_GRPC_HOST, creds) node_channel = grpc.secure_channel(CLN_GRPC_HOST, creds)
# Create the gRPC stub
hstub = hold_pb2_grpc.HoldStub(hold_channel)
nstub = node_pb2_grpc.NodeStub(node_channel)
payment_failure_context = { payment_failure_context = {
-1: "Catchall nonspecific error.", -1: "Catchall nonspecific error.",
201: "Already paid with this hash using different amount or destination.", 201: "Already paid with this hash using different amount or destination.",
@ -65,9 +61,9 @@ class CLNNode:
@classmethod @classmethod
def get_version(cls): def get_version(cls):
try: try:
nstub = node_pb2_grpc.NodeStub(cls.node_channel) nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
request = node_pb2.GetinfoRequest() request = node_pb2.GetinfoRequest()
response = nstub.Getinfo(request) response = nodestub.Getinfo(request)
return response.version return response.version
except Exception as e: except Exception as e:
print(f"Cannot get CLN version: {e}") print(f"Cannot get CLN version: {e}")
@ -77,8 +73,8 @@ class CLNNode:
def decode_payreq(cls, invoice): def decode_payreq(cls, invoice):
"""Decodes a lightning payment request (invoice)""" """Decodes a lightning payment request (invoice)"""
request = hold_pb2.DecodeBolt11Request(bolt11=invoice) request = hold_pb2.DecodeBolt11Request(bolt11=invoice)
holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = cls.hstub.DecodeBolt11(request) response = holdstub.DecodeBolt11(request)
return response return response
@classmethod @classmethod
@ -86,8 +82,8 @@ class CLNNode:
"""Returns estimated fee for onchain payouts""" """Returns estimated fee for onchain payouts"""
# feerate estimaes work a bit differently in cln see https://lightning.readthedocs.io/lightning-feerates.7.html # feerate estimaes work a bit differently in cln see https://lightning.readthedocs.io/lightning-feerates.7.html
request = node_pb2.FeeratesRequest(style="PERKB") request = node_pb2.FeeratesRequest(style="PERKB")
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = cls.nstub.Feerates(request) response = nodestub.Feerates(request)
# "opening" -> ~12 block target # "opening" -> ~12 block target
return { return {
@ -102,8 +98,8 @@ class CLNNode:
def wallet_balance(cls): def wallet_balance(cls):
"""Returns onchain balance""" """Returns onchain balance"""
request = node_pb2.ListfundsRequest() request = node_pb2.ListfundsRequest()
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = cls.nstub.ListFunds(request) response = nodestub.ListFunds(request)
unconfirmed_balance = 0 unconfirmed_balance = 0
confirmed_balance = 0 confirmed_balance = 0
@ -136,8 +132,8 @@ class CLNNode:
def channel_balance(cls): def channel_balance(cls):
"""Returns channels balance""" """Returns channels balance"""
request = node_pb2.ListpeerchannelsRequest() request = node_pb2.ListpeerchannelsRequest()
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = cls.nstub.ListPeerChannels(request) response = nodestub.ListPeerChannels(request)
local_balance_sat = 0 local_balance_sat = 0
remote_balance_sat = 0 remote_balance_sat = 0
@ -199,7 +195,8 @@ class CLNNode:
# Changing the state to "MEMPO" should be atomic with SendCoins. # Changing the state to "MEMPO" should be atomic with SendCoins.
onchainpayment.status = on_mempool_code onchainpayment.status = on_mempool_code
onchainpayment.save(update_fields=["status"]) onchainpayment.save(update_fields=["status"])
response = cls.nstub.Withdraw(request) nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = nodestub.Withdraw(request)
if response.txid: if response.txid:
onchainpayment.txid = response.txid.hex() onchainpayment.txid = response.txid.hex()
@ -217,7 +214,8 @@ class CLNNode:
request = hold_pb2.HoldInvoiceCancelRequest( request = hold_pb2.HoldInvoiceCancelRequest(
payment_hash=bytes.fromhex(payment_hash) payment_hash=bytes.fromhex(payment_hash)
) )
response = cls.hstub.HoldInvoiceCancel(request) holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.HoldInvoiceCancel(request)
return response.state == hold_pb2.HoldInvoiceCancelResponse.Holdstate.CANCELED return response.state == hold_pb2.HoldInvoiceCancelResponse.Holdstate.CANCELED
@ -227,7 +225,8 @@ class CLNNode:
request = hold_pb2.HoldInvoiceSettleRequest( request = hold_pb2.HoldInvoiceSettleRequest(
payment_hash=hashlib.sha256(bytes.fromhex(preimage)).digest() payment_hash=hashlib.sha256(bytes.fromhex(preimage)).digest()
) )
response = cls.hstub.HoldInvoiceSettle(request) holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.HoldInvoiceSettle(request)
return response.state == hold_pb2.HoldInvoiceSettleResponse.Holdstate.SETTLED return response.state == hold_pb2.HoldInvoiceSettleResponse.Holdstate.SETTLED
@ -260,7 +259,8 @@ class CLNNode:
cltv=cltv_expiry_blocks, cltv=cltv_expiry_blocks,
preimage=preimage, # preimage is actually optional in cln, as cln would generate one by default preimage=preimage, # preimage is actually optional in cln, as cln would generate one by default
) )
response = cls.hstub.HoldInvoice(request) holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.HoldInvoice(request)
hold_payment["invoice"] = response.bolt11 hold_payment["invoice"] = response.bolt11
payreq_decoded = cls.decode_payreq(hold_payment["invoice"]) payreq_decoded = cls.decode_payreq(hold_payment["invoice"])
@ -284,7 +284,8 @@ class CLNNode:
request = hold_pb2.HoldInvoiceLookupRequest( request = hold_pb2.HoldInvoiceLookupRequest(
payment_hash=bytes.fromhex(lnpayment.payment_hash) payment_hash=bytes.fromhex(lnpayment.payment_hash)
) )
response = cls.hstub.HoldInvoiceLookup(request) holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.HoldInvoiceLookup(request)
# Will fail if 'unable to locate invoice'. Happens if invoice expiry # Will fail if 'unable to locate invoice'. Happens if invoice expiry
# time has passed (but these are 15% padded at the moment). Should catch it # time has passed (but these are 15% padded at the moment). Should catch it
@ -324,7 +325,8 @@ class CLNNode:
request = hold_pb2.HoldInvoiceLookupRequest( request = hold_pb2.HoldInvoiceLookupRequest(
payment_hash=bytes.fromhex(lnpayment.payment_hash) payment_hash=bytes.fromhex(lnpayment.payment_hash)
) )
response = cls.hstub.HoldInvoiceLookup(request) holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.HoldInvoiceLookup(request)
status = cln_response_state_to_lnpayment_status[response.state] status = cln_response_state_to_lnpayment_status[response.state]
@ -345,7 +347,8 @@ class CLNNode:
payment_hash=bytes.fromhex(lnpayment.payment_hash) payment_hash=bytes.fromhex(lnpayment.payment_hash)
) )
try: try:
response2 = cls.nstub.ListInvoices(request2).invoices nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response2 = nodestub.ListInvoices(request2).invoices
except Exception as e: except Exception as e:
print(str(e)) print(str(e))
@ -482,7 +485,8 @@ class CLNNode:
) )
try: try:
response = cls.nstub.Pay(request) nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = nodestub.Pay(request)
if response.status == node_pb2.PayResponse.PayStatus.COMPLETE: if response.status == node_pb2.PayResponse.PayStatus.COMPLETE:
lnpayment.status = LNPayment.Status.SUCCED lnpayment.status = LNPayment.Status.SUCCED
@ -540,7 +544,8 @@ class CLNNode:
) )
while True: while True:
try: try:
response_listpays = cls.nstub.ListPays(request_listpays) nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response_listpays = nodestub.ListPays(request_listpays)
except Exception as e: except Exception as e:
print(str(e)) print(str(e))
time.sleep(2) time.sleep(2)
@ -562,8 +567,8 @@ class CLNNode:
lnpayment.save(update_fields=["in_flight", "status"]) lnpayment.save(update_fields=["in_flight", "status"])
order.update_status(Order.Status.PAY) order.update_status(Order.Status.PAY)
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = cls.nstub.Pay(request) response = nodestub.Pay(request)
if response.status == node_pb2.PayResponse.PayStatus.PENDING: if response.status == node_pb2.PayResponse.PayStatus.PENDING:
print(f"Order: {order.id} IN_FLIGHT. Hash {hash}") print(f"Order: {order.id} IN_FLIGHT. Hash {hash}")
@ -758,9 +763,11 @@ class CLNNode:
) )
) )
if sign: if sign:
self_pubkey = cls.nstub.Getinfo(node_pb2.GetinfoRequest()).id nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
self_pubkey = nodestub.Getinfo(node_pb2.GetinfoRequest()).id
timestamp = struct.pack(">i", int(time.time())) timestamp = struct.pack(">i", int(time.time()))
signature = cls.nstub.SignMessage( nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
signature = nodestub.SignMessage(
node_pb2.SignmessageRequest( node_pb2.SignmessageRequest(
message=( message=(
bytes.fromhex(self_pubkey) bytes.fromhex(self_pubkey)
@ -791,7 +798,8 @@ class CLNNode:
retry_for=timeout, retry_for=timeout,
amount_msat=primitives__pb2.Amount(msat=num_satoshis * 1000), amount_msat=primitives__pb2.Amount(msat=num_satoshis * 1000),
) )
response = cls.nstub.KeySend(request) nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = nodestub.KeySend(request)
keysend_payment["preimage"] = response.payment_preimage.hex() keysend_payment["preimage"] = response.payment_preimage.hex()
keysend_payment["payment_hash"] = response.payment_hash.hex() keysend_payment["payment_hash"] = response.payment_hash.hex()
@ -800,7 +808,8 @@ class CLNNode:
payment_hash=response.payment_hash, timeout=timeout payment_hash=response.payment_hash, timeout=timeout
) )
try: try:
waitresp = cls.nstub.WaitSendPay(waitreq) nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
waitresp = nodestub.WaitSendPay(waitreq)
keysend_payment["fee"] = ( keysend_payment["fee"] = (
float(waitresp.amount_sent_msat.msat - waitresp.amount_msat.msat) float(waitresp.amount_sent_msat.msat - waitresp.amount_msat.msat)
/ 1000 / 1000
@ -833,7 +842,8 @@ class CLNNode:
payment_hash=bytes.fromhex(payment_hash) payment_hash=bytes.fromhex(payment_hash)
) )
try: try:
response = cls.hstub.HoldInvoiceLookup(request) holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.HoldInvoiceLookup(request)
except Exception as e: except Exception as e:
if "Timed out" in str(e): if "Timed out" in str(e):
return False return False