From bece7c7d4eda8e1c5a061b3496d94820274ae367 Mon Sep 17 00:00:00 2001 From: Reckless_Satoshi Date: Wed, 8 Nov 2023 14:56:36 +0000 Subject: [PATCH] Refactor CLN stubs to allow for mocking --- api/lightning/cln.py | 70 +++++++++++++++++++++++++------------------- 1 file changed, 40 insertions(+), 30 deletions(-) diff --git a/api/lightning/cln.py b/api/lightning/cln.py index cb92dd15..928bf23f 100755 --- a/api/lightning/cln.py +++ b/api/lightning/cln.py @@ -48,10 +48,6 @@ class CLNNode: hold_channel = grpc.secure_channel(CLN_GRPC_HOLD_HOST, creds) node_channel = grpc.secure_channel(CLN_GRPC_HOST, creds) - # Create the gRPC stub - hstub = hold_pb2_grpc.HoldStub(hold_channel) - nstub = node_pb2_grpc.NodeStub(node_channel) - payment_failure_context = { -1: "Catchall nonspecific error.", 201: "Already paid with this hash using different amount or destination.", @@ -65,9 +61,9 @@ class CLNNode: @classmethod def get_version(cls): try: - nstub = node_pb2_grpc.NodeStub(cls.node_channel) + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) request = node_pb2.GetinfoRequest() - response = nstub.Getinfo(request) + response = nodestub.Getinfo(request) return response.version except Exception as e: print(f"Cannot get CLN version: {e}") @@ -77,8 +73,8 @@ class CLNNode: def decode_payreq(cls, invoice): """Decodes a lightning payment request (invoice)""" request = hold_pb2.DecodeBolt11Request(bolt11=invoice) - - response = cls.hstub.DecodeBolt11(request) + holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel) + response = holdstub.DecodeBolt11(request) return response @classmethod @@ -86,8 +82,8 @@ class CLNNode: """Returns estimated fee for onchain payouts""" # feerate estimaes work a bit differently in cln see https://lightning.readthedocs.io/lightning-feerates.7.html request = node_pb2.FeeratesRequest(style="PERKB") - - response = cls.nstub.Feerates(request) + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) + response = nodestub.Feerates(request) # "opening" -> ~12 block target return { @@ -102,8 +98,8 @@ class CLNNode: def wallet_balance(cls): """Returns onchain balance""" request = node_pb2.ListfundsRequest() - - response = cls.nstub.ListFunds(request) + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) + response = nodestub.ListFunds(request) unconfirmed_balance = 0 confirmed_balance = 0 @@ -136,8 +132,8 @@ class CLNNode: def channel_balance(cls): """Returns channels balance""" request = node_pb2.ListpeerchannelsRequest() - - response = cls.nstub.ListPeerChannels(request) + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) + response = nodestub.ListPeerChannels(request) local_balance_sat = 0 remote_balance_sat = 0 @@ -199,7 +195,8 @@ class CLNNode: # Changing the state to "MEMPO" should be atomic with SendCoins. onchainpayment.status = on_mempool_code onchainpayment.save(update_fields=["status"]) - response = cls.nstub.Withdraw(request) + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) + response = nodestub.Withdraw(request) if response.txid: onchainpayment.txid = response.txid.hex() @@ -217,7 +214,8 @@ class CLNNode: request = hold_pb2.HoldInvoiceCancelRequest( payment_hash=bytes.fromhex(payment_hash) ) - response = cls.hstub.HoldInvoiceCancel(request) + holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel) + response = holdstub.HoldInvoiceCancel(request) return response.state == hold_pb2.HoldInvoiceCancelResponse.Holdstate.CANCELED @@ -227,7 +225,8 @@ class CLNNode: request = hold_pb2.HoldInvoiceSettleRequest( payment_hash=hashlib.sha256(bytes.fromhex(preimage)).digest() ) - response = cls.hstub.HoldInvoiceSettle(request) + holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel) + response = holdstub.HoldInvoiceSettle(request) return response.state == hold_pb2.HoldInvoiceSettleResponse.Holdstate.SETTLED @@ -260,7 +259,8 @@ class CLNNode: cltv=cltv_expiry_blocks, preimage=preimage, # preimage is actually optional in cln, as cln would generate one by default ) - response = cls.hstub.HoldInvoice(request) + holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel) + response = holdstub.HoldInvoice(request) hold_payment["invoice"] = response.bolt11 payreq_decoded = cls.decode_payreq(hold_payment["invoice"]) @@ -284,7 +284,8 @@ class CLNNode: request = hold_pb2.HoldInvoiceLookupRequest( payment_hash=bytes.fromhex(lnpayment.payment_hash) ) - response = cls.hstub.HoldInvoiceLookup(request) + holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel) + response = holdstub.HoldInvoiceLookup(request) # Will fail if 'unable to locate invoice'. Happens if invoice expiry # time has passed (but these are 15% padded at the moment). Should catch it @@ -324,7 +325,8 @@ class CLNNode: request = hold_pb2.HoldInvoiceLookupRequest( payment_hash=bytes.fromhex(lnpayment.payment_hash) ) - response = cls.hstub.HoldInvoiceLookup(request) + holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel) + response = holdstub.HoldInvoiceLookup(request) status = cln_response_state_to_lnpayment_status[response.state] @@ -345,7 +347,8 @@ class CLNNode: payment_hash=bytes.fromhex(lnpayment.payment_hash) ) try: - response2 = cls.nstub.ListInvoices(request2).invoices + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) + response2 = nodestub.ListInvoices(request2).invoices except Exception as e: print(str(e)) @@ -482,7 +485,8 @@ class CLNNode: ) try: - response = cls.nstub.Pay(request) + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) + response = nodestub.Pay(request) if response.status == node_pb2.PayResponse.PayStatus.COMPLETE: lnpayment.status = LNPayment.Status.SUCCED @@ -540,7 +544,8 @@ class CLNNode: ) while True: try: - response_listpays = cls.nstub.ListPays(request_listpays) + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) + response_listpays = nodestub.ListPays(request_listpays) except Exception as e: print(str(e)) time.sleep(2) @@ -562,8 +567,8 @@ class CLNNode: lnpayment.save(update_fields=["in_flight", "status"]) order.update_status(Order.Status.PAY) - - response = cls.nstub.Pay(request) + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) + response = nodestub.Pay(request) if response.status == node_pb2.PayResponse.PayStatus.PENDING: print(f"Order: {order.id} IN_FLIGHT. Hash {hash}") @@ -758,9 +763,11 @@ class CLNNode: ) ) if sign: - self_pubkey = cls.nstub.Getinfo(node_pb2.GetinfoRequest()).id + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) + self_pubkey = nodestub.Getinfo(node_pb2.GetinfoRequest()).id timestamp = struct.pack(">i", int(time.time())) - signature = cls.nstub.SignMessage( + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) + signature = nodestub.SignMessage( node_pb2.SignmessageRequest( message=( bytes.fromhex(self_pubkey) @@ -791,7 +798,8 @@ class CLNNode: retry_for=timeout, amount_msat=primitives__pb2.Amount(msat=num_satoshis * 1000), ) - response = cls.nstub.KeySend(request) + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) + response = nodestub.KeySend(request) keysend_payment["preimage"] = response.payment_preimage.hex() keysend_payment["payment_hash"] = response.payment_hash.hex() @@ -800,7 +808,8 @@ class CLNNode: payment_hash=response.payment_hash, timeout=timeout ) try: - waitresp = cls.nstub.WaitSendPay(waitreq) + nodestub = node_pb2_grpc.NodeStub(cls.node_channel) + waitresp = nodestub.WaitSendPay(waitreq) keysend_payment["fee"] = ( float(waitresp.amount_sent_msat.msat - waitresp.amount_msat.msat) / 1000 @@ -833,7 +842,8 @@ class CLNNode: payment_hash=bytes.fromhex(payment_hash) ) try: - response = cls.hstub.HoldInvoiceLookup(request) + holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel) + response = holdstub.HoldInvoiceLookup(request) except Exception as e: if "Timed out" in str(e): return False