Refactor gRPC and mocks. Add coordinator info test.

This commit is contained in:
Reckless_Satoshi 2023-11-08 14:25:34 +00:00 committed by Reckless_Satoshi
parent 89ae6cd4a6
commit 4efc59d416
7 changed files with 330 additions and 182 deletions

View File

@ -10,10 +10,7 @@ import ring
from decouple import config from decouple import config
from django.utils import timezone from django.utils import timezone
from . import hold_pb2 as holdrpc from . import hold_pb2, hold_pb2_grpc, node_pb2, node_pb2_grpc
from . import hold_pb2_grpc as holdstub
from . import node_pb2 as noderpc
from . import node_pb2_grpc as nodestub
from . import primitives_pb2 as primitives__pb2 from . import primitives_pb2 as primitives__pb2
####### #######
@ -52,11 +49,8 @@ class CLNNode:
node_channel = grpc.secure_channel(CLN_GRPC_HOST, creds) node_channel = grpc.secure_channel(CLN_GRPC_HOST, creds)
# Create the gRPC stub # Create the gRPC stub
hstub = holdstub.HoldStub(hold_channel) hstub = hold_pb2_grpc.HoldStub(hold_channel)
nstub = nodestub.NodeStub(node_channel) nstub = node_pb2_grpc.NodeStub(node_channel)
holdrpc = holdrpc
noderpc = noderpc
payment_failure_context = { payment_failure_context = {
-1: "Catchall nonspecific error.", -1: "Catchall nonspecific error.",
@ -71,19 +65,18 @@ class CLNNode:
@classmethod @classmethod
def get_version(cls): def get_version(cls):
try: try:
request = noderpc.GetinfoRequest() nstub = node_pb2_grpc.NodeStub(cls.node_channel)
print(request) request = node_pb2.GetinfoRequest()
response = cls.nstub.Getinfo(request) response = nstub.Getinfo(request)
print(response)
return response.version return response.version
except Exception as e: except Exception as e:
print(e) print(f"Cannot get CLN version: {e}")
return None return None
@classmethod @classmethod
def decode_payreq(cls, invoice): def decode_payreq(cls, invoice):
"""Decodes a lightning payment request (invoice)""" """Decodes a lightning payment request (invoice)"""
request = holdrpc.DecodeBolt11Request(bolt11=invoice) request = hold_pb2.DecodeBolt11Request(bolt11=invoice)
response = cls.hstub.DecodeBolt11(request) response = cls.hstub.DecodeBolt11(request)
return response return response
@ -92,7 +85,7 @@ class CLNNode:
def estimate_fee(cls, amount_sats, target_conf=2, min_confs=1): def estimate_fee(cls, amount_sats, target_conf=2, min_confs=1):
"""Returns estimated fee for onchain payouts""" """Returns estimated fee for onchain payouts"""
# feerate estimaes work a bit differently in cln see https://lightning.readthedocs.io/lightning-feerates.7.html # feerate estimaes work a bit differently in cln see https://lightning.readthedocs.io/lightning-feerates.7.html
request = noderpc.FeeratesRequest(style="PERKB") request = node_pb2.FeeratesRequest(style="PERKB")
response = cls.nstub.Feerates(request) response = cls.nstub.Feerates(request)
@ -108,7 +101,7 @@ class CLNNode:
@classmethod @classmethod
def wallet_balance(cls): def wallet_balance(cls):
"""Returns onchain balance""" """Returns onchain balance"""
request = noderpc.ListfundsRequest() request = node_pb2.ListfundsRequest()
response = cls.nstub.ListFunds(request) response = cls.nstub.ListFunds(request)
@ -119,13 +112,13 @@ class CLNNode:
if not utxo.reserved: if not utxo.reserved:
if ( if (
utxo.status utxo.status
== noderpc.ListfundsOutputs.ListfundsOutputsStatus.UNCONFIRMED == node_pb2.ListfundsOutputs.ListfundsOutputsStatus.UNCONFIRMED
): ):
unconfirmed_balance += utxo.amount_msat.msat // 1_000 unconfirmed_balance += utxo.amount_msat.msat // 1_000
total_balance += utxo.amount_msat.msat // 1_000 total_balance += utxo.amount_msat.msat // 1_000
elif ( elif (
utxo.status utxo.status
== noderpc.ListfundsOutputs.ListfundsOutputsStatus.CONFIRMED == node_pb2.ListfundsOutputs.ListfundsOutputsStatus.CONFIRMED
): ):
confirmed_balance += utxo.amount_msat.msat // 1_000 confirmed_balance += utxo.amount_msat.msat // 1_000
total_balance += utxo.amount_msat.msat // 1_000 total_balance += utxo.amount_msat.msat // 1_000
@ -142,7 +135,7 @@ class CLNNode:
@classmethod @classmethod
def channel_balance(cls): def channel_balance(cls):
"""Returns channels balance""" """Returns channels balance"""
request = noderpc.ListpeerchannelsRequest() request = node_pb2.ListpeerchannelsRequest()
response = cls.nstub.ListPeerChannels(request) response = cls.nstub.ListPeerChannels(request)
@ -153,7 +146,7 @@ class CLNNode:
for channel in response.channels: for channel in response.channels:
if ( if (
channel.state channel.state
== noderpc.ListpeerchannelsChannels.ListpeerchannelsChannelsState.CHANNELD_NORMAL == node_pb2.ListpeerchannelsChannels.ListpeerchannelsChannelsState.CHANNELD_NORMAL
): ):
local_balance_sat += channel.to_us_msat.msat // 1_000 local_balance_sat += channel.to_us_msat.msat // 1_000
remote_balance_sat += ( remote_balance_sat += (
@ -162,12 +155,12 @@ class CLNNode:
for htlc in channel.htlcs: for htlc in channel.htlcs:
if ( if (
htlc.direction htlc.direction
== noderpc.ListpeerchannelsChannelsHtlcs.ListpeerchannelsChannelsHtlcsDirection.IN == node_pb2.ListpeerchannelsChannelsHtlcs.ListpeerchannelsChannelsHtlcsDirection.IN
): ):
unsettled_local_balance += htlc.amount_msat.msat // 1_000 unsettled_local_balance += htlc.amount_msat.msat // 1_000
elif ( elif (
htlc.direction htlc.direction
== noderpc.ListpeerchannelsChannelsHtlcs.ListpeerchannelsChannelsHtlcsDirection.OUT == node_pb2.ListpeerchannelsChannelsHtlcs.ListpeerchannelsChannelsHtlcsDirection.OUT
): ):
unsettled_remote_balance += htlc.amount_msat.msat // 1_000 unsettled_remote_balance += htlc.amount_msat.msat // 1_000
@ -185,7 +178,7 @@ class CLNNode:
if DISABLE_ONCHAIN or onchainpayment.sent_satoshis > MAX_SWAP_AMOUNT: if DISABLE_ONCHAIN or onchainpayment.sent_satoshis > MAX_SWAP_AMOUNT:
return False return False
request = noderpc.WithdrawRequest( request = node_pb2.WithdrawRequest(
destination=onchainpayment.address, destination=onchainpayment.address,
satoshi=primitives__pb2.AmountOrAll( satoshi=primitives__pb2.AmountOrAll(
amount=primitives__pb2.Amount(msat=onchainpayment.sent_satoshis * 1_000) amount=primitives__pb2.Amount(msat=onchainpayment.sent_satoshis * 1_000)
@ -221,22 +214,22 @@ class CLNNode:
@classmethod @classmethod
def cancel_return_hold_invoice(cls, payment_hash): def cancel_return_hold_invoice(cls, payment_hash):
"""Cancels or returns a hold invoice""" """Cancels or returns a hold invoice"""
request = holdrpc.HoldInvoiceCancelRequest( request = hold_pb2.HoldInvoiceCancelRequest(
payment_hash=bytes.fromhex(payment_hash) payment_hash=bytes.fromhex(payment_hash)
) )
response = cls.hstub.HoldInvoiceCancel(request) response = cls.hstub.HoldInvoiceCancel(request)
return response.state == holdrpc.HoldInvoiceCancelResponse.Holdstate.CANCELED return response.state == hold_pb2.HoldInvoiceCancelResponse.Holdstate.CANCELED
@classmethod @classmethod
def settle_hold_invoice(cls, preimage): def settle_hold_invoice(cls, preimage):
"""settles a hold invoice""" """settles a hold invoice"""
request = holdrpc.HoldInvoiceSettleRequest( request = hold_pb2.HoldInvoiceSettleRequest(
payment_hash=hashlib.sha256(bytes.fromhex(preimage)).digest() payment_hash=hashlib.sha256(bytes.fromhex(preimage)).digest()
) )
response = cls.hstub.HoldInvoiceSettle(request) response = cls.hstub.HoldInvoiceSettle(request)
return response.state == holdrpc.HoldInvoiceSettleResponse.Holdstate.SETTLED return response.state == hold_pb2.HoldInvoiceSettleResponse.Holdstate.SETTLED
@classmethod @classmethod
def gen_hold_invoice( def gen_hold_invoice(
@ -259,7 +252,7 @@ class CLNNode:
# The preimage is a random hash of 256 bits entropy # The preimage is a random hash of 256 bits entropy
preimage = hashlib.sha256(secrets.token_bytes(nbytes=32)).digest() preimage = hashlib.sha256(secrets.token_bytes(nbytes=32)).digest()
request = holdrpc.HoldInvoiceRequest( request = hold_pb2.HoldInvoiceRequest(
description=description, description=description,
amount_msat=primitives__pb2.Amount(msat=num_satoshis * 1_000), amount_msat=primitives__pb2.Amount(msat=num_satoshis * 1_000),
label=f"Order:{order_id}-{lnpayment_concept}-{time}", label=f"Order:{order_id}-{lnpayment_concept}-{time}",
@ -288,7 +281,7 @@ class CLNNode:
"""Checks if hold invoice is locked""" """Checks if hold invoice is locked"""
from api.models import LNPayment from api.models import LNPayment
request = holdrpc.HoldInvoiceLookupRequest( request = hold_pb2.HoldInvoiceLookupRequest(
payment_hash=bytes.fromhex(lnpayment.payment_hash) payment_hash=bytes.fromhex(lnpayment.payment_hash)
) )
response = cls.hstub.HoldInvoiceLookup(request) response = cls.hstub.HoldInvoiceLookup(request)
@ -296,13 +289,13 @@ class CLNNode:
# Will fail if 'unable to locate invoice'. Happens if invoice expiry # Will fail if 'unable to locate invoice'. Happens if invoice expiry
# time has passed (but these are 15% padded at the moment). Should catch it # time has passed (but these are 15% padded at the moment). Should catch it
# and report back that the invoice has expired (better robustness) # and report back that the invoice has expired (better robustness)
if response.state == holdrpc.HoldInvoiceLookupResponse.Holdstate.OPEN: if response.state == hold_pb2.HoldInvoiceLookupResponse.Holdstate.OPEN:
pass pass
if response.state == holdrpc.HoldInvoiceLookupResponse.Holdstate.SETTLED: if response.state == hold_pb2.HoldInvoiceLookupResponse.Holdstate.SETTLED:
pass pass
if response.state == holdrpc.HoldInvoiceLookupResponse.Holdstate.CANCELED: if response.state == hold_pb2.HoldInvoiceLookupResponse.Holdstate.CANCELED:
pass pass
if response.state == holdrpc.HoldInvoiceLookupResponse.Holdstate.ACCEPTED: if response.state == hold_pb2.HoldInvoiceLookupResponse.Holdstate.ACCEPTED:
lnpayment.expiry_height = response.htlc_expiry lnpayment.expiry_height = response.htlc_expiry
lnpayment.status = LNPayment.Status.LOCKED lnpayment.status = LNPayment.Status.LOCKED
lnpayment.save(update_fields=["expiry_height", "status"]) lnpayment.save(update_fields=["expiry_height", "status"])
@ -328,7 +321,7 @@ class CLNNode:
try: try:
# this is similar to LNNnode.validate_hold_invoice_locked # this is similar to LNNnode.validate_hold_invoice_locked
request = holdrpc.HoldInvoiceLookupRequest( request = hold_pb2.HoldInvoiceLookupRequest(
payment_hash=bytes.fromhex(lnpayment.payment_hash) payment_hash=bytes.fromhex(lnpayment.payment_hash)
) )
response = cls.hstub.HoldInvoiceLookup(request) response = cls.hstub.HoldInvoiceLookup(request)
@ -348,7 +341,7 @@ class CLNNode:
# (cln-grpc-hodl has separate state for hodl-invoices, which it forgets after an invoice expired more than an hour ago) # (cln-grpc-hodl has separate state for hodl-invoices, which it forgets after an invoice expired more than an hour ago)
if "empty result for listdatastore_state" in str(e): if "empty result for listdatastore_state" in str(e):
print(str(e)) print(str(e))
request2 = noderpc.ListinvoicesRequest( request2 = node_pb2.ListinvoicesRequest(
payment_hash=bytes.fromhex(lnpayment.payment_hash) payment_hash=bytes.fromhex(lnpayment.payment_hash)
) )
try: try:
@ -358,12 +351,12 @@ class CLNNode:
if ( if (
response2[0].status response2[0].status
== noderpc.ListinvoicesInvoices.ListinvoicesInvoicesStatus.PAID == node_pb2.ListinvoicesInvoices.ListinvoicesInvoicesStatus.PAID
): ):
status = LNPayment.Status.SETLED status = LNPayment.Status.SETLED
elif ( elif (
response2[0].status response2[0].status
== noderpc.ListinvoicesInvoices.ListinvoicesInvoicesStatus.EXPIRED == node_pb2.ListinvoicesInvoices.ListinvoicesInvoicesStatus.EXPIRED
): ):
status = LNPayment.Status.CANCEL status = LNPayment.Status.CANCEL
else: else:
@ -482,7 +475,7 @@ class CLNNode:
) )
) # 200 ppm or 10 sats ) # 200 ppm or 10 sats
timeout_seconds = int(config("REWARDS_TIMEOUT_SECONDS")) timeout_seconds = int(config("REWARDS_TIMEOUT_SECONDS"))
request = noderpc.PayRequest( request = node_pb2.PayRequest(
bolt11=lnpayment.invoice, bolt11=lnpayment.invoice,
maxfee=primitives__pb2.Amount(msat=fee_limit_sat * 1_000), maxfee=primitives__pb2.Amount(msat=fee_limit_sat * 1_000),
retry_for=timeout_seconds, retry_for=timeout_seconds,
@ -491,7 +484,7 @@ class CLNNode:
try: try:
response = cls.nstub.Pay(request) response = cls.nstub.Pay(request)
if response.status == noderpc.PayResponse.PayStatus.COMPLETE: if response.status == node_pb2.PayResponse.PayStatus.COMPLETE:
lnpayment.status = LNPayment.Status.SUCCED lnpayment.status = LNPayment.Status.SUCCED
lnpayment.fee = ( lnpayment.fee = (
float(response.amount_sent_msat.msat - response.amount_msat.msat) float(response.amount_sent_msat.msat - response.amount_msat.msat)
@ -500,13 +493,13 @@ class CLNNode:
lnpayment.preimage = response.payment_preimage.hex() lnpayment.preimage = response.payment_preimage.hex()
lnpayment.save(update_fields=["fee", "status", "preimage"]) lnpayment.save(update_fields=["fee", "status", "preimage"])
return True, None return True, None
elif response.status == noderpc.PayResponse.PayStatus.PENDING: elif response.status == node_pb2.PayResponse.PayStatus.PENDING:
failure_reason = "Payment isn't failed (yet)" failure_reason = "Payment isn't failed (yet)"
lnpayment.failure_reason = LNPayment.FailureReason.NOTYETF lnpayment.failure_reason = LNPayment.FailureReason.NOTYETF
lnpayment.status = LNPayment.Status.FLIGHT lnpayment.status = LNPayment.Status.FLIGHT
lnpayment.save(update_fields=["failure_reason", "status"]) lnpayment.save(update_fields=["failure_reason", "status"])
return False, failure_reason return False, failure_reason
else: # response.status == noderpc.PayResponse.PayStatus.FAILED else: # response.status == node_pb2.PayResponse.PayStatus.FAILED
failure_reason = "All possible routes were tried and failed permanently. Or were no routes to the destination at all." failure_reason = "All possible routes were tried and failed permanently. Or were no routes to the destination at all."
lnpayment.failure_reason = LNPayment.FailureReason.NOROUTE lnpayment.failure_reason = LNPayment.FailureReason.NOROUTE
lnpayment.status = LNPayment.Status.FAILRO lnpayment.status = LNPayment.Status.FAILRO
@ -530,7 +523,7 @@ class CLNNode:
# retry_for is not quite the same as a timeout. Pay can still take SIGNIFICANTLY longer to return if htlcs are stuck! # retry_for is not quite the same as a timeout. Pay can still take SIGNIFICANTLY longer to return if htlcs are stuck!
# allow_self_payment=True, No such thing in pay command and self_payments do not work with pay! # allow_self_payment=True, No such thing in pay command and self_payments do not work with pay!
request = noderpc.PayRequest( request = node_pb2.PayRequest(
bolt11=lnpayment.invoice, bolt11=lnpayment.invoice,
maxfee=primitives__pb2.Amount(msat=fee_limit_sat * 1_000), maxfee=primitives__pb2.Amount(msat=fee_limit_sat * 1_000),
retry_for=timeout_seconds, retry_for=timeout_seconds,
@ -542,7 +535,9 @@ class CLNNode:
return return
def watchpayment(): def watchpayment():
request_listpays = noderpc.ListpaysRequest(payment_hash=bytes.fromhex(hash)) request_listpays = node_pb2.ListpaysRequest(
payment_hash=bytes.fromhex(hash)
)
while True: while True:
try: try:
response_listpays = cls.nstub.ListPays(request_listpays) response_listpays = cls.nstub.ListPays(request_listpays)
@ -554,7 +549,7 @@ class CLNNode:
if ( if (
len(response_listpays.pays) == 0 len(response_listpays.pays) == 0
or response_listpays.pays[0].status or response_listpays.pays[0].status
!= noderpc.ListpaysPays.ListpaysPaysStatus.PENDING != node_pb2.ListpaysPays.ListpaysPaysStatus.PENDING
): ):
return response_listpays return response_listpays
else: else:
@ -570,14 +565,14 @@ class CLNNode:
response = cls.nstub.Pay(request) response = cls.nstub.Pay(request)
if response.status == noderpc.PayResponse.PayStatus.PENDING: if response.status == node_pb2.PayResponse.PayStatus.PENDING:
print(f"Order: {order.id} IN_FLIGHT. Hash {hash}") print(f"Order: {order.id} IN_FLIGHT. Hash {hash}")
watchpayment() watchpayment()
handle_response() handle_response()
if response.status == noderpc.PayResponse.PayStatus.FAILED: if response.status == node_pb2.PayResponse.PayStatus.FAILED:
lnpayment.status = LNPayment.Status.FAILRO lnpayment.status = LNPayment.Status.FAILRO
lnpayment.last_routing_time = timezone.now() lnpayment.last_routing_time = timezone.now()
lnpayment.routing_attempts += 1 lnpayment.routing_attempts += 1
@ -614,7 +609,7 @@ class CLNNode:
"context": f"payment failure reason: {cls.payment_failure_context[-1]}", "context": f"payment failure reason: {cls.payment_failure_context[-1]}",
} }
if response.status == noderpc.PayResponse.PayStatus.COMPLETE: if response.status == node_pb2.PayResponse.PayStatus.COMPLETE:
print(f"Order: {order.id} SUCCEEDED. Hash: {hash}") print(f"Order: {order.id} SUCCEEDED. Hash: {hash}")
lnpayment.status = LNPayment.Status.SUCCED lnpayment.status = LNPayment.Status.SUCCED
lnpayment.fee = ( lnpayment.fee = (
@ -702,7 +697,7 @@ class CLNNode:
if ( if (
len(last_payresponse.pays) > 0 len(last_payresponse.pays) > 0
and last_payresponse.pays[0].status and last_payresponse.pays[0].status
== noderpc.ListpaysPays.ListpaysPaysStatus.COMPLETE == node_pb2.ListpaysPays.ListpaysPaysStatus.COMPLETE
): ):
handle_response() handle_response()
else: else:
@ -763,10 +758,10 @@ class CLNNode:
) )
) )
if sign: if sign:
self_pubkey = cls.nstub.GetInfo(noderpc.GetinfoRequest()).id self_pubkey = cls.nstub.Getinfo(node_pb2.GetinfoRequest()).id
timestamp = struct.pack(">i", int(time.time())) timestamp = struct.pack(">i", int(time.time()))
signature = cls.nstub.SignMessage( signature = cls.nstub.SignMessage(
noderpc.SignmessageRequest( node_pb2.SignmessageRequest(
message=( message=(
bytes.fromhex(self_pubkey) bytes.fromhex(self_pubkey)
+ bytes.fromhex(target_pubkey) + bytes.fromhex(target_pubkey)
@ -789,7 +784,7 @@ class CLNNode:
# no maxfee for Keysend # no maxfee for Keysend
maxfeepercent = (routing_budget_sats / num_satoshis) * 100 maxfeepercent = (routing_budget_sats / num_satoshis) * 100
request = noderpc.KeysendRequest( request = node_pb2.KeysendRequest(
destination=bytes.fromhex(target_pubkey), destination=bytes.fromhex(target_pubkey),
extratlvs=primitives__pb2.TlvStream(entries=custom_records), extratlvs=primitives__pb2.TlvStream(entries=custom_records),
maxfeepercent=maxfeepercent, maxfeepercent=maxfeepercent,
@ -801,7 +796,7 @@ class CLNNode:
keysend_payment["preimage"] = response.payment_preimage.hex() keysend_payment["preimage"] = response.payment_preimage.hex()
keysend_payment["payment_hash"] = response.payment_hash.hex() keysend_payment["payment_hash"] = response.payment_hash.hex()
waitreq = noderpc.WaitsendpayRequest( waitreq = node_pb2.WaitsendpayRequest(
payment_hash=response.payment_hash, timeout=timeout payment_hash=response.payment_hash, timeout=timeout
) )
try: try:
@ -834,7 +829,7 @@ class CLNNode:
@classmethod @classmethod
def double_check_htlc_is_settled(cls, payment_hash): def double_check_htlc_is_settled(cls, payment_hash):
"""Just as it sounds. Better safe than sorry!""" """Just as it sounds. Better safe than sorry!"""
request = holdrpc.HoldInvoiceLookupRequest( request = hold_pb2.HoldInvoiceLookupRequest(
payment_hash=bytes.fromhex(payment_hash) payment_hash=bytes.fromhex(payment_hash)
) )
try: try:
@ -845,4 +840,4 @@ class CLNNode:
else: else:
raise e raise e
return response.state == holdrpc.HoldInvoiceLookupResponse.Holdstate.SETTLED return response.state == hold_pb2.HoldInvoiceLookupResponse.Holdstate.SETTLED

View File

@ -11,16 +11,18 @@ import ring
from decouple import config from decouple import config
from django.utils import timezone from django.utils import timezone
from . import invoices_pb2 as invoicesrpc from . import (
from . import invoices_pb2_grpc as invoicesstub invoices_pb2,
from . import lightning_pb2 as lnrpc invoices_pb2_grpc,
from . import lightning_pb2_grpc as lightningstub lightning_pb2,
from . import router_pb2 as routerrpc lightning_pb2_grpc,
from . import router_pb2_grpc as routerstub router_pb2,
from . import signer_pb2 as signerrpc router_pb2_grpc,
from . import signer_pb2_grpc as signerstub signer_pb2,
from . import verrpc_pb2 as verrpc signer_pb2_grpc,
from . import verrpc_pb2_grpc as verstub verrpc_pb2,
verrpc_pb2_grpc,
)
####### #######
# Works with LND (c-lightning in the future for multi-vendor resilience) # Works with LND (c-lightning in the future for multi-vendor resilience)
@ -67,12 +69,6 @@ class LNDNode:
combined_creds = grpc.composite_channel_credentials(ssl_creds, auth_creds) combined_creds = grpc.composite_channel_credentials(ssl_creds, auth_creds)
channel = grpc.secure_channel(LND_GRPC_HOST, combined_creds) channel = grpc.secure_channel(LND_GRPC_HOST, combined_creds)
lightningstub = lightningstub.LightningStub(channel)
invoicesstub = invoicesstub.InvoicesStub(channel)
routerstub = routerstub.RouterStub(channel)
signerstub = signerstub.SignerStub(channel)
verstub = verstub.VersionerStub(channel)
payment_failure_context = { payment_failure_context = {
0: "Payment isn't failed (yet)", 0: "Payment isn't failed (yet)",
1: "There are more routes to try, but the payment timeout was exceeded.", 1: "There are more routes to try, but the payment timeout was exceeded.",
@ -85,8 +81,9 @@ class LNDNode:
@classmethod @classmethod
def get_version(cls): def get_version(cls):
try: try:
request = verrpc.VersionRequest() request = verrpc_pb2.VersionRequest()
response = cls.verstub.GetVersion(request) verstub = verrpc_pb2_grpc.VersionerStub(cls.channel)
response = verstub.GetVersion(request)
log("verstub.GetVersion", request, response) log("verstub.GetVersion", request, response)
return "v" + response.version return "v" + response.version
except Exception as e: except Exception as e:
@ -96,33 +93,35 @@ class LNDNode:
@classmethod @classmethod
def decode_payreq(cls, invoice): def decode_payreq(cls, invoice):
"""Decodes a lightning payment request (invoice)""" """Decodes a lightning payment request (invoice)"""
request = lnrpc.PayReqString(pay_req=invoice) lightningstub = lightning_pb2_grpc.LightningStub(cls.channel)
response = cls.lightningstub.DecodePayReq(request) request = lightning_pb2.PayReqString(pay_req=invoice)
log("lightningstub.DecodePayReq", request, response) response = lightningstub.DecodePayReq(request)
log("lightning_pb2_grpc.DecodePayReq", request, response)
return response return response
@classmethod @classmethod
def estimate_fee(cls, amount_sats, target_conf=2, min_confs=1): def estimate_fee(cls, amount_sats, target_conf=2, min_confs=1):
"""Returns estimated fee for onchain payouts""" """Returns estimated fee for onchain payouts"""
lightningstub = lightning_pb2_grpc.LightningStub(cls.channel)
request = lnrpc.GetInfoRequest() request = lightning_pb2.GetInfoRequest()
response = lightningstub.GetInfo(request) response = lightningstub.GetInfo(request)
log("lightningstub.GetInfo", request, response) log("lightning_pb2_grpc.GetInfo", request, response)
if response.testnet: if response.testnet:
dummy_address = "tb1qehyqhruxwl2p5pt52k6nxj4v8wwc3f3pg7377x" dummy_address = "tb1qehyqhruxwl2p5pt52k6nxj4v8wwc3f3pg7377x"
else: else:
dummy_address = "bc1qgxwaqe4m9mypd7ltww53yv3lyxhcfnhzzvy5j3" dummy_address = "bc1qgxwaqe4m9mypd7ltww53yv3lyxhcfnhzzvy5j3"
# We assume segwit. Use hardcoded address as shortcut so there is no need of user inputs yet. # We assume segwit. Use hardcoded address as shortcut so there is no need of user inputs yet.
request = lnrpc.EstimateFeeRequest( request = lightning_pb2.EstimateFeeRequest(
AddrToAmount={dummy_address: amount_sats}, AddrToAmount={dummy_address: amount_sats},
target_conf=target_conf, target_conf=target_conf,
min_confs=min_confs, min_confs=min_confs,
spend_unconfirmed=False, spend_unconfirmed=False,
) )
response = cls.lightningstub.EstimateFee(request) lightningstub = lightning_pb2_grpc.LightningStub(cls.channel)
log("lightningstub.EstimateFee", request, response) response = lightningstub.EstimateFee(request)
log("lightning_pb2_grpc.EstimateFee", request, response)
return { return {
"mining_fee_sats": response.fee_sat, "mining_fee_sats": response.fee_sat,
@ -135,9 +134,10 @@ class LNDNode:
@classmethod @classmethod
def wallet_balance(cls): def wallet_balance(cls):
"""Returns onchain balance""" """Returns onchain balance"""
request = lnrpc.WalletBalanceRequest() lightningstub = lightning_pb2_grpc.LightningStub(cls.channel)
response = cls.lightningstub.WalletBalance(request) request = lightning_pb2.WalletBalanceRequest()
log("lightningstub.WalletBalance", request, response) response = lightningstub.WalletBalance(request)
log("lightning_pb2_grpc.WalletBalance", request, response)
return { return {
"total_balance": response.total_balance, "total_balance": response.total_balance,
@ -151,9 +151,10 @@ class LNDNode:
@classmethod @classmethod
def channel_balance(cls): def channel_balance(cls):
"""Returns channels balance""" """Returns channels balance"""
request = lnrpc.ChannelBalanceRequest() lightningstub = lightning_pb2_grpc.LightningStub(cls.channel)
response = cls.lightningstub.ChannelBalance(request) request = lightning_pb2.ChannelBalanceRequest()
log("lightningstub.ChannelBalance", request, response) response = lightningstub.ChannelBalance(request)
log("lightning_pb2_grpc.ChannelBalance", request, response)
return { return {
"local_balance": response.local_balance.sat, "local_balance": response.local_balance.sat,
@ -169,7 +170,7 @@ class LNDNode:
if DISABLE_ONCHAIN or onchainpayment.sent_satoshis > MAX_SWAP_AMOUNT: if DISABLE_ONCHAIN or onchainpayment.sent_satoshis > MAX_SWAP_AMOUNT:
return False return False
request = lnrpc.SendCoinsRequest( request = lightning_pb2.SendCoinsRequest(
addr=onchainpayment.address, addr=onchainpayment.address,
amount=int(onchainpayment.sent_satoshis), amount=int(onchainpayment.sent_satoshis),
sat_per_vbyte=int(onchainpayment.mining_fee_rate), sat_per_vbyte=int(onchainpayment.mining_fee_rate),
@ -187,8 +188,9 @@ class LNDNode:
# Changing the state to "MEMPO" should be atomic with SendCoins. # Changing the state to "MEMPO" should be atomic with SendCoins.
onchainpayment.status = on_mempool_code onchainpayment.status = on_mempool_code
onchainpayment.save(update_fields=["status"]) onchainpayment.save(update_fields=["status"])
response = cls.lightningstub.SendCoins(request) lightningstub = lightning_pb2_grpc.LightningStub(cls.channel)
log("lightningstub.SendCoins", request, response) response = lightningstub.SendCoins(request)
log("lightning_pb2_grpc.SendCoins", request, response)
if response.txid: if response.txid:
onchainpayment.txid = response.txid onchainpayment.txid = response.txid
@ -210,18 +212,22 @@ class LNDNode:
@classmethod @classmethod
def cancel_return_hold_invoice(cls, payment_hash): def cancel_return_hold_invoice(cls, payment_hash):
"""Cancels or returns a hold invoice""" """Cancels or returns a hold invoice"""
request = invoicesrpc.CancelInvoiceMsg(payment_hash=bytes.fromhex(payment_hash)) request = invoices_pb2.CancelInvoiceMsg(
response = cls.invoicesstub.CancelInvoice(request) payment_hash=bytes.fromhex(payment_hash)
log("invoicesstub.CancelInvoice", request, response) )
invoicesstub = invoices_pb2_grpc.InvoicesStub(cls.channel)
response = invoicesstub.CancelInvoice(request)
log("invoices_pb2_grpc.CancelInvoice", request, response)
# Fix this: tricky because canceling sucessfully an invoice has no response. TODO # Fix this: tricky because canceling sucessfully an invoice has no response. TODO
return str(response) == "" # True if no response, false otherwise. return str(response) == "" # True if no response, false otherwise.
@classmethod @classmethod
def settle_hold_invoice(cls, preimage): def settle_hold_invoice(cls, preimage):
"""settles a hold invoice""" """settles a hold invoice"""
request = invoicesrpc.SettleInvoiceMsg(preimage=bytes.fromhex(preimage)) request = invoices_pb2.SettleInvoiceMsg(preimage=bytes.fromhex(preimage))
response = cls.invoicesstub.SettleInvoice(request) invoicesstub = invoices_pb2_grpc.InvoicesStub(cls.channel)
log("invoicesstub.SettleInvoice", request, response) response = invoicesstub.SettleInvoice(request)
log("invoices_pb2_grpc.SettleInvoice", request, response)
# Fix this: tricky because settling sucessfully an invoice has None response. TODO # Fix this: tricky because settling sucessfully an invoice has None response. TODO
return str(response) == "" # True if no response, false otherwise. return str(response) == "" # True if no response, false otherwise.
@ -244,7 +250,7 @@ class LNDNode:
# Its hash is used to generate the hold invoice # Its hash is used to generate the hold invoice
r_hash = hashlib.sha256(preimage).digest() r_hash = hashlib.sha256(preimage).digest()
request = invoicesrpc.AddHoldInvoiceRequest( request = invoices_pb2.AddHoldInvoiceRequest(
memo=description, memo=description,
value=num_satoshis, value=num_satoshis,
hash=r_hash, hash=r_hash,
@ -253,8 +259,9 @@ class LNDNode:
), # actual expiry is padded by 50%, if tight, wrong client system clock will say invoice is expired. ), # actual expiry is padded by 50%, if tight, wrong client system clock will say invoice is expired.
cltv_expiry=cltv_expiry_blocks, cltv_expiry=cltv_expiry_blocks,
) )
response = cls.invoicesstub.AddHoldInvoice(request) invoicesstub = invoices_pb2_grpc.InvoicesStub(cls.channel)
log("invoicesstub.AddHoldInvoice", request, response) response = invoicesstub.AddHoldInvoice(request)
log("invoices_pb2_grpc.AddHoldInvoice", request, response)
hold_payment["invoice"] = response.payment_request hold_payment["invoice"] = response.payment_request
payreq_decoded = cls.decode_payreq(hold_payment["invoice"]) payreq_decoded = cls.decode_payreq(hold_payment["invoice"])
@ -275,22 +282,25 @@ class LNDNode:
"""Checks if hold invoice is locked""" """Checks if hold invoice is locked"""
from api.models import LNPayment from api.models import LNPayment
request = invoicesrpc.LookupInvoiceMsg( request = invoices_pb2.LookupInvoiceMsg(
payment_hash=bytes.fromhex(lnpayment.payment_hash) payment_hash=bytes.fromhex(lnpayment.payment_hash)
) )
response = cls.invoicesstub.LookupInvoiceV2(request) invoicesstub = invoices_pb2_grpc.InvoicesStub(cls.channel)
log("invoicesstub.LookupInvoiceV2", request, response) response = invoicesstub.LookupInvoiceV2(request)
log("invoices_pb2_grpc.LookupInvoiceV2", request, response)
# Will fail if 'unable to locate invoice'. Happens if invoice expiry # Will fail if 'unable to locate invoice'. Happens if invoice expiry
# time has passed (but these are 15% padded at the moment). Should catch it # time has passed (but these are 15% padded at the moment). Should catch it
# and report back that the invoice has expired (better robustness) # and report back that the invoice has expired (better robustness)
if response.state == lnrpc.Invoice.InvoiceState.OPEN: # OPEN if response.state == lightning_pb2.Invoice.InvoiceState.OPEN: # OPEN
pass pass
if response.state == lnrpc.Invoice.InvoiceState.SETTLED: # SETTLED if response.state == lightning_pb2.Invoice.InvoiceState.SETTLED: # SETTLED
pass pass
if response.state == lnrpc.Invoice.InvoiceState.CANCELED: # CANCELED if response.state == lightning_pb2.Invoice.InvoiceState.CANCELED: # CANCELED
pass pass
if response.state == lnrpc.Invoice.InvoiceState.ACCEPTED: # ACCEPTED (LOCKED) if (
response.state == lightning_pb2.Invoice.InvoiceState.ACCEPTED
): # ACCEPTED (LOCKED)
lnpayment.expiry_height = response.htlcs[0].expiry_height lnpayment.expiry_height = response.htlcs[0].expiry_height
lnpayment.status = LNPayment.Status.LOCKED lnpayment.status = LNPayment.Status.LOCKED
lnpayment.save(update_fields=["expiry_height", "status"]) lnpayment.save(update_fields=["expiry_height", "status"])
@ -316,11 +326,12 @@ class LNDNode:
try: try:
# this is similar to LNNnode.validate_hold_invoice_locked # this is similar to LNNnode.validate_hold_invoice_locked
request = invoicesrpc.LookupInvoiceMsg( request = invoices_pb2.LookupInvoiceMsg(
payment_hash=bytes.fromhex(lnpayment.payment_hash) payment_hash=bytes.fromhex(lnpayment.payment_hash)
) )
response = cls.invoicesstub.LookupInvoiceV2(request) invoicesstub = invoices_pb2_grpc.InvoicesStub(cls.channel)
log("invoicesstub.LookupInvoiceV2", request, response) response = invoicesstub.LookupInvoiceV2(request)
log("invoices_pb2_grpc.LookupInvoiceV2", request, response)
status = lnd_response_state_to_lnpayment_status[response.state] status = lnd_response_state_to_lnpayment_status[response.state]
@ -351,8 +362,9 @@ class LNDNode:
@classmethod @classmethod
def resetmc(cls): def resetmc(cls):
request = routerrpc.ResetMissionControlRequest() routerstub = router_pb2_grpc.RouterStub(cls.channel)
_ = cls.routerstub.ResetMissionControl(request) request = router_pb2.ResetMissionControlRequest()
_ = routerstub.ResetMissionControl(request)
return True return True
@classmethod @classmethod
@ -459,27 +471,28 @@ class LNDNode:
) )
) # 200 ppm or 10 sats ) # 200 ppm or 10 sats
timeout_seconds = int(config("REWARDS_TIMEOUT_SECONDS")) timeout_seconds = int(config("REWARDS_TIMEOUT_SECONDS"))
request = routerrpc.SendPaymentRequest( request = router_pb2.SendPaymentRequest(
payment_request=lnpayment.invoice, payment_request=lnpayment.invoice,
fee_limit_sat=fee_limit_sat, fee_limit_sat=fee_limit_sat,
timeout_seconds=timeout_seconds, timeout_seconds=timeout_seconds,
) )
for response in cls.routerstub.SendPaymentV2(request): routerstub = router_pb2_grpc.RouterStub(cls.channel)
log("routerstub.SendPaymentV2", request, response) for response in routerstub.SendPaymentV2(request):
log("router_pb2_grpc.SendPaymentV2", request, response)
if ( if (
response.status == lnrpc.Payment.PaymentStatus.UNKNOWN response.status == lightning_pb2.Payment.PaymentStatus.UNKNOWN
): # Status 0 'UNKNOWN' ): # Status 0 'UNKNOWN'
# Not sure when this status happens # Not sure when this status happens
pass pass
if ( if (
response.status == lnrpc.Payment.PaymentStatus.IN_FLIGHT response.status == lightning_pb2.Payment.PaymentStatus.IN_FLIGHT
): # Status 1 'IN_FLIGHT' ): # Status 1 'IN_FLIGHT'
pass pass
if ( if (
response.status == lnrpc.Payment.PaymentStatus.FAILED response.status == lightning_pb2.Payment.PaymentStatus.FAILED
): # Status 3 'FAILED' ): # Status 3 'FAILED'
"""0 Payment isn't failed (yet). """0 Payment isn't failed (yet).
1 There are more routes to try, but the payment timeout was exceeded. 1 There are more routes to try, but the payment timeout was exceeded.
@ -495,7 +508,7 @@ class LNDNode:
return False, failure_reason return False, failure_reason
if ( if (
response.status == lnrpc.Payment.PaymentStatus.SUCCEEDED response.status == lightning_pb2.Payment.PaymentStatus.SUCCEEDED
): # STATUS 'SUCCEEDED' ): # STATUS 'SUCCEEDED'
lnpayment.status = LNPayment.Status.SUCCED lnpayment.status = LNPayment.Status.SUCCED
lnpayment.fee = float(response.fee_msat) / 1000 lnpayment.fee = float(response.fee_msat) / 1000
@ -515,7 +528,7 @@ class LNDNode:
hash = lnpayment.payment_hash hash = lnpayment.payment_hash
request = routerrpc.SendPaymentRequest( request = router_pb2.SendPaymentRequest(
payment_request=lnpayment.invoice, payment_request=lnpayment.invoice,
fee_limit_sat=fee_limit_sat, fee_limit_sat=fee_limit_sat,
timeout_seconds=timeout_seconds, timeout_seconds=timeout_seconds,
@ -535,7 +548,7 @@ class LNDNode:
order.save(update_fields=["status"]) order.save(update_fields=["status"])
if ( if (
response.status == lnrpc.Payment.PaymentStatus.UNKNOWN response.status == lightning_pb2.Payment.PaymentStatus.UNKNOWN
): # Status 0 'UNKNOWN' ): # Status 0 'UNKNOWN'
# Not sure when this status happens # Not sure when this status happens
print(f"Order: {order.id} UNKNOWN. Hash {hash}") print(f"Order: {order.id} UNKNOWN. Hash {hash}")
@ -543,7 +556,7 @@ class LNDNode:
lnpayment.save(update_fields=["in_flight"]) lnpayment.save(update_fields=["in_flight"])
if ( if (
response.status == lnrpc.Payment.PaymentStatus.IN_FLIGHT response.status == lightning_pb2.Payment.PaymentStatus.IN_FLIGHT
): # Status 1 'IN_FLIGHT' ): # Status 1 'IN_FLIGHT'
print(f"Order: {order.id} IN_FLIGHT. Hash {hash}") print(f"Order: {order.id} IN_FLIGHT. Hash {hash}")
@ -556,7 +569,7 @@ class LNDNode:
lnpayment.save(update_fields=["last_routing_time"]) lnpayment.save(update_fields=["last_routing_time"])
if ( if (
response.status == lnrpc.Payment.PaymentStatus.FAILED response.status == lightning_pb2.Payment.PaymentStatus.FAILED
): # Status 3 'FAILED' ): # Status 3 'FAILED'
lnpayment.status = LNPayment.Status.FAILRO lnpayment.status = LNPayment.Status.FAILRO
lnpayment.last_routing_time = timezone.now() lnpayment.last_routing_time = timezone.now()
@ -599,7 +612,7 @@ class LNDNode:
} }
if ( if (
response.status == lnrpc.Payment.PaymentStatus.SUCCEEDED response.status == lightning_pb2.Payment.PaymentStatus.SUCCEEDED
): # Status 2 'SUCCEEDED' ): # Status 2 'SUCCEEDED'
print(f"Order: {order.id} SUCCEEDED. Hash: {hash}") print(f"Order: {order.id} SUCCEEDED. Hash: {hash}")
lnpayment.status = LNPayment.Status.SUCCED lnpayment.status = LNPayment.Status.SUCCED
@ -621,8 +634,9 @@ class LNDNode:
return results return results
try: try:
for response in cls.routerstub.SendPaymentV2(request): routerstub = router_pb2_grpc.RouterStub(cls.channel)
log("routerstub.SendPaymentV2", request, response) for response in routerstub.SendPaymentV2(request):
log("router_pb2_grpc.SendPaymentV2", request, response)
handle_response(response) handle_response(response)
except Exception as e: except Exception as e:
@ -630,12 +644,13 @@ class LNDNode:
print(f"Order: {order.id}. INVOICE EXPIRED. Hash: {hash}") print(f"Order: {order.id}. INVOICE EXPIRED. Hash: {hash}")
# An expired invoice can already be in-flight. Check. # An expired invoice can already be in-flight. Check.
try: try:
request = routerrpc.TrackPaymentRequest( request = router_pb2.TrackPaymentRequest(
payment_hash=bytes.fromhex(hash) payment_hash=bytes.fromhex(hash)
) )
for response in cls.routerstub.TrackPaymentV2(request): routerstub = router_pb2_grpc.RouterStub(cls.channel)
log("routerstub.TrackPaymentV2", request, response) for response in routerstub.TrackPaymentV2(request):
log("router_pb2_grpc.TrackPaymentV2", request, response)
handle_response(response, was_in_transit=True) handle_response(response, was_in_transit=True)
except Exception as e: except Exception as e:
@ -670,23 +685,25 @@ class LNDNode:
elif "payment is in transition" in str(e): elif "payment is in transition" in str(e):
print(f"Order: {order.id} ALREADY IN TRANSITION. Hash: {hash}.") print(f"Order: {order.id} ALREADY IN TRANSITION. Hash: {hash}.")
request = routerrpc.TrackPaymentRequest( request = router_pb2.TrackPaymentRequest(
payment_hash=bytes.fromhex(hash) payment_hash=bytes.fromhex(hash)
) )
for response in cls.routerstub.TrackPaymentV2(request): routerstub = router_pb2_grpc.RouterStub(cls.channel)
log("routerstub.TrackPaymentV2", request, response) for response in routerstub.TrackPaymentV2(request):
log("router_pb2_grpc.TrackPaymentV2", request, response)
handle_response(response, was_in_transit=True) handle_response(response, was_in_transit=True)
elif "invoice is already paid" in str(e): elif "invoice is already paid" in str(e):
print(f"Order: {order.id} ALREADY PAID. Hash: {hash}.") print(f"Order: {order.id} ALREADY PAID. Hash: {hash}.")
request = routerrpc.TrackPaymentRequest( request = router_pb2.TrackPaymentRequest(
payment_hash=bytes.fromhex(hash) payment_hash=bytes.fromhex(hash)
) )
for response in cls.routerstub.TrackPaymentV2(request): routerstub = router_pb2_grpc.RouterStub(cls.channel)
log("routerstub.TrackPaymentV2", request, response) for response in routerstub.TrackPaymentV2(request):
log("router_pb2_grpc.TrackPaymentV2", request, response)
handle_response(response) handle_response(response)
else: else:
@ -721,26 +738,28 @@ class LNDNode:
(34349334, bytes.fromhex(msg.encode("utf-8").hex())) (34349334, bytes.fromhex(msg.encode("utf-8").hex()))
) )
if sign: if sign:
self_pubkey = cls.lightningstub.GetInfo( lightningstub = lightning_pb2_grpc.LightningStub(cls.channel)
lnrpc.GetInfoRequest() self_pubkey = lightningstub.GetInfo(
lightning_pb2.GetInfoRequest()
).identity_pubkey ).identity_pubkey
timestamp = struct.pack(">i", int(time.time())) timestamp = struct.pack(">i", int(time.time()))
signature = cls.signerstub.SignMessage( signerstub = signer_pb2_grpc.SignerStub(cls.channel)
signerrpc.SignMessageReq( signature = signerstub.SignMessage(
signer_pb2.SignMessageReq(
msg=( msg=(
bytes.fromhex(self_pubkey) bytes.fromhex(self_pubkey)
+ bytes.fromhex(target_pubkey) + bytes.fromhex(target_pubkey)
+ timestamp + timestamp
+ bytes.fromhex(msg.encode("utf-8").hex()) + bytes.fromhex(msg.encode("utf-8").hex())
), ),
key_loc=signerrpc.KeyLocator(key_family=6, key_index=0), key_loc=signer_pb2.KeyLocator(key_family=6, key_index=0),
) )
).signature ).signature
custom_records.append((34349337, signature)) custom_records.append((34349337, signature))
custom_records.append((34349339, bytes.fromhex(self_pubkey))) custom_records.append((34349339, bytes.fromhex(self_pubkey)))
custom_records.append((34349343, timestamp)) custom_records.append((34349343, timestamp))
request = routerrpc.SendPaymentRequest( request = router_pb2.SendPaymentRequest(
dest=bytes.fromhex(target_pubkey), dest=bytes.fromhex(target_pubkey),
dest_custom_records=custom_records, dest_custom_records=custom_records,
fee_limit_sat=routing_budget_sats, fee_limit_sat=routing_budget_sats,
@ -749,17 +768,18 @@ class LNDNode:
payment_hash=bytes.fromhex(hashed_secret), payment_hash=bytes.fromhex(hashed_secret),
allow_self_payment=ALLOW_SELF_KEYSEND, allow_self_payment=ALLOW_SELF_KEYSEND,
) )
for response in cls.routerstub.SendPaymentV2(request): routerstub = router_pb2_grpc.RouterStub(cls.channel)
log("routerstub.SendPaymentV2", request, response) for response in routerstub.SendPaymentV2(request):
if response.status == lnrpc.Payment.PaymentStatus.IN_FLIGHT: log("router_pb2_grpc.SendPaymentV2", request, response)
if response.status == lightning_pb2.Payment.PaymentStatus.IN_FLIGHT:
keysend_payment["status"] = LNPayment.Status.FLIGHT keysend_payment["status"] = LNPayment.Status.FLIGHT
if response.status == lnrpc.Payment.PaymentStatus.SUCCEEDED: if response.status == lightning_pb2.Payment.PaymentStatus.SUCCEEDED:
keysend_payment["fee"] = float(response.fee_msat) / 1000 keysend_payment["fee"] = float(response.fee_msat) / 1000
keysend_payment["status"] = LNPayment.Status.SUCCED keysend_payment["status"] = LNPayment.Status.SUCCED
if response.status == lnrpc.Payment.PaymentStatus.FAILED: if response.status == lightning_pb2.Payment.PaymentStatus.FAILED:
keysend_payment["status"] = LNPayment.Status.FAILRO keysend_payment["status"] = LNPayment.Status.FAILRO
keysend_payment["failure_reason"] = response.failure_reason keysend_payment["failure_reason"] = response.failure_reason
if response.status == lnrpc.Payment.PaymentStatus.UNKNOWN: if response.status == lightning_pb2.Payment.PaymentStatus.UNKNOWN:
print("Unknown Error") print("Unknown Error")
except Exception as e: except Exception as e:
if "self-payments not allowed" in str(e): if "self-payments not allowed" in str(e):
@ -772,10 +792,13 @@ class LNDNode:
@classmethod @classmethod
def double_check_htlc_is_settled(cls, payment_hash): def double_check_htlc_is_settled(cls, payment_hash):
"""Just as it sounds. Better safe than sorry!""" """Just as it sounds. Better safe than sorry!"""
request = invoicesrpc.LookupInvoiceMsg(payment_hash=bytes.fromhex(payment_hash)) request = invoices_pb2.LookupInvoiceMsg(
response = cls.invoicesstub.LookupInvoiceV2(request) payment_hash=bytes.fromhex(payment_hash)
log("invoicesstub.LookupInvoiceV2", request, response) )
invoicesstub = invoices_pb2_grpc.InvoicesStub(cls.channel)
response = invoicesstub.LookupInvoiceV2(request)
log("invoices_pb2_grpc.LookupInvoiceV2", request, response)
return ( return (
response.state == lnrpc.Invoice.InvoiceState.SETTLED response.state == lightning_pb2.Invoice.InvoiceState.SETTLED
) # LND states: 0 OPEN, 1 SETTLED, 3 ACCEPTED, GRPC_ERROR status 5 when CANCELED/returned ) # LND states: 0 OPEN, 1 SETTLED, 3 ACCEPTED, GRPC_ERROR status 5 when CANCELED/returned

View File

@ -1,7 +1,6 @@
from unittest.mock import MagicMock, Mock, mock_open, patch from unittest.mock import MagicMock, Mock, mock_open, patch
import numpy as np import numpy as np
from decouple import config
from django.test import TestCase from django.test import TestCase
from api.models import Order from api.models import Order
@ -22,6 +21,8 @@ from api.utils import (
verify_signed_message, verify_signed_message,
weighted_median, weighted_median,
) )
from tests.mocks.cln import MockNodeStub
from tests.mocks.lnd import MockVersionerStub
class TestUtils(TestCase): class TestUtils(TestCase):
@ -95,25 +96,19 @@ class TestUtils(TestCase):
mock_response_blockchain.json.assert_called_once() mock_response_blockchain.json.assert_called_once()
mock_response_yadio.json.assert_called_once() mock_response_yadio.json.assert_called_once()
LNVENDOR = config("LNVENDOR", cast=str, default="LND") @patch("api.lightning.lnd.verrpc_pb2_grpc.VersionerStub", MockVersionerStub)
def test_get_lnd_version(self):
if LNVENDOR == "LND":
@patch("api.lightning.lnd.LNDNode.get_version")
def test_get_lnd_version(self, mock_get_version):
mock_get_version.return_value = "v0.17.0-beta"
version = get_lnd_version() version = get_lnd_version()
self.assertEqual(version, "v0.17.0-beta") self.assertEqual(version, "v0.17.0-beta")
elif LNVENDOR == "CLN": @patch("api.lightning.cln.node_pb2_grpc.NodeStub", MockNodeStub)
def test_get_cln_version(self):
@patch("api.lightning.cln.CLNNode.get_version")
def test_get_cln_version(self, mock_get_version):
mock_get_version.return_value = "v23.08.1"
version = get_cln_version() version = get_cln_version()
self.assertEqual(version, "v23.08.1") self.assertEqual(version, "v23.08")
@patch("builtins.open", new_callable=mock_open, read_data="test_commit_hash") @patch(
"builtins.open", new_callable=mock_open, read_data="00000000000000000000 dev"
)
def test_get_robosats_commit(self, mock_file): def test_get_robosats_commit(self, mock_file):
# Call the get_robosats_commit function # Call the get_robosats_commit function
commit_hash = get_robosats_commit() commit_hash = get_robosats_commit()

58
tests/mocks/cln.py Normal file
View File

@ -0,0 +1,58 @@
from unittest.mock import MagicMock
# Mock up of CLN gRPC responses
class MockNodeStub:
def __init__(channel, other):
pass
def Getinfo(self, request):
response = MagicMock()
response.id = b"\002\202Y\300\330\2564\005\357\263\221;\300\266\326F\010}\370/\252&!v\221iM\251\241V\241\034\034"
response.alias = "ROBOSATS-TEST-CLN-v23.08"
response.color = "\002\202Y"
response.num_peers = 1
response.num_active_channels = 1
response.version = "v23.08"
response.lightning_dir = "/root/.lightning/testnet"
response.our_features.init = b"\010\240\000\n\002i\242"
response.our_features.node = b"\210\240\000\n\002i\242"
response.our_features.invoice = b"\002\000\000\002\002A\000"
response.blockheight = 2100000
response.network = "testnet"
response.fees_collected_msat.msat: 21000
response.address.item_type = "TORV3"
response.address.port = 19735
response.address.address = (
"21000000gwfmvmig5xlzc2yzm6uzisode5vhs7kyegwstu5hflhx5fid.onion"
)
response.binding.item_type = "IPV6"
response.binding.address = "127.0.0.1"
response.binding.port = 9736
return response
class MockHoldStub:
def __init__(channel, other):
pass
def HoldInvoiceLookup(self, request):
response = MagicMock()
return response
def HoldInvoice(self, request):
response = MagicMock()
return response
def HoldInvoiceSettle(self, request):
response = MagicMock()
return response
def HoldInvoiceCancel(self, request):
response = MagicMock()
return response
def DecodeBolt11(self, request):
response = MagicMock()
return response

View File

@ -17,20 +17,30 @@ class MockLightningStub:
def DecodePayReq(self, request): def DecodePayReq(self, request):
response = MagicMock() response = MagicMock()
if request.pay_req == "lntb17314....x": if (
response.destination = "00000000" request.pay_req
response.payment_hash = "00000000" == "lntb17310n1pj552mdpp50p2utzh7mpsf3uq7u7cws4a96tj3kyq54hchdkpw8zecamx9klrqd2j2pshjmt9de6zqun9vejhyetwvdjn5gphxs6nsvfe893z6wphvfsj6dryvymj6wp5xvuz6wp5xcukvdec8yukgcf49cs9g6rfwvs8qcted4jkuapq2ay5cnpqgefy2326g5syjn3qt984253q2aq5cnz92skzqcmgv43kkgr0dcs9ymmzdafkzarnyp5kvgr5dpjjqmr0vd4jqampwvs8xatrvdjhxumxw4kzugzfwss8w6tvdssxyefqw4hxcmmrddjkggpgveskjmpfyp6kumr9wdejq7t0w5sxx6r9v96zqmmjyp3kzmnrv4kzqatwd9kxzar9wfskcmre9ccqz52xqzwzsp5hkzegrhn6kegr33z8qfxtcudaklugygdrakgyy7va0wt2qs7drfq9qyyssqc6rztchzl4m7mlulrhlcajszcl9fan8908k9n5x7gmz8g8d6ht5pj4l8r0dushq6j5s8x7yv9a5klz0kfxwy8v6ze6adyrrp4wu0q0sq3t604x"
):
response.destination = (
"033b58d7681fe5dd2fb21fd741996cda5449616f77317dd1156b80128d6a71b807"
)
response.payment_hash = (
"7855c58afed86098f01ee7b0e857a5d2e51b1014adf176d82e38b38eecc5b7c6"
)
response.num_satoshis = 1731 response.num_satoshis = 1731
response.timestamp = 1699359597 response.timestamp = 1699359597
response.expiry = 450 response.expiry = 450
response.description = "Payment reference: xxxxxxxxxxxxxxxxxxxxxxx. This payment WILL FREEZE IN YOUR WALLET, check on RoboSats if the lock was successful. It will be unlocked (fail) unless you cheat or cancel unilaterally." response.description = "Payment reference: 7458199b-87ba-4da7-8438-8469f7899da5. This payment WILL FREEZE IN YOUR WALLET, check on RoboSats if the lock was successful. It will be unlocked (fail) unless you cheat or cancel unilaterally."
response.cltv_expiry = 650 response.cltv_expiry = 650
response.payment_addr = "\275\205\224\002\036h\322" response.payment_addr = '\275\205\224\016\363\325\262\201\306"8\022e\343\215\355\277\304\021\r\037l\202\023\314\353\334\265\002\036h\322'
response.num_msat = 1731000 response.num_msat = 1731000
def CancelInvoice(self, request): def CancelInvoice(self, request):
response = MagicMock() response = MagicMock()
if request == b"xU\305\212\306": if (
request
== b"xU\305\212\376\330`\230\360\036\347\260\350W\245\322\345\033\020\024\255\361v\330.8\263\216\354\305\267\306"
):
response = {} response = {}
return response return response
@ -69,9 +79,9 @@ class MockInvoicesStub:
def AddHoldInvoice(self, request): def AddHoldInvoice(self, request):
response = MagicMock() response = MagicMock()
if request.value == 1731: if request.value == 1731:
response.payment_request = "lntb17314....x" response.payment_request = "lntb17310n1pj552mdpp50p2utzh7mpsf3uq7u7cws4a96tj3kyq54hchdkpw8zecamx9klrqd2j2pshjmt9de6zqun9vejhyetwvdjn5gphxs6nsvfe893z6wphvfsj6dryvymj6wp5xvuz6wp5xcukvdec8yukgcf49cs9g6rfwvs8qcted4jkuapq2ay5cnpqgefy2326g5syjn3qt984253q2aq5cnz92skzqcmgv43kkgr0dcs9ymmzdafkzarnyp5kvgr5dpjjqmr0vd4jqampwvs8xatrvdjhxumxw4kzugzfwss8w6tvdssxyefqw4hxcmmrddjkggpgveskjmpfyp6kumr9wdejq7t0w5sxx6r9v96zqmmjyp3kzmnrv4kzqatwd9kxzar9wfskcmre9ccqz52xqzwzsp5hkzegrhn6kegr33z8qfxtcudaklugygdrakgyy7va0wt2qs7drfq9qyyssqc6rztchzl4m7mlulrhlcajszcl9fan8908k9n5x7gmz8g8d6ht5pj4l8r0dushq6j5s8x7yv9a5klz0kfxwy8v6ze6adyrrp4wu0q0sq3t604x"
response.add_index = 1 response.add_index = 1
response.payment_addr = b"\275\205\322" response.payment_addr = b'\275\205\224\016\363\325\262\201\306"8\022e\343\215\355\277\304\021\r\037l\202\023\314\353\334\265\002\036h\322'
def CancelInvoice(self, request): def CancelInvoice(self, request):
response = MagicMock() response = MagicMock()
@ -107,6 +117,9 @@ class MockSignerStub:
class MockVersionerStub: class MockVersionerStub:
def __init__(channel, other):
pass
def GetVersion(self, request): def GetVersion(self, request):
response = MagicMock() response = MagicMock()
response.commit = "v0.17.0-beta" response.commit = "v0.17.0-beta"

View File

@ -0,0 +1,61 @@
import json
from unittest.mock import patch
from decouple import config
from django.conf import settings
from django.contrib.auth.models import User
from django.test import Client, TestCase
from tests.mocks.cln import MockNodeStub
from tests.mocks.lnd import MockVersionerStub
FEE = config("FEE", cast=float, default=0.2)
NODE_ID = config("NODE_ID", cast=str, default="033b58d7......")
MAKER_FEE = FEE * config("FEE_SPLIT", cast=float, default=0.125)
TAKER_FEE = FEE * (1 - config("FEE_SPLIT", cast=float, default=0.125))
BOND_SIZE = config("BOND_SIZE", cast=float, default=3)
NOTICE_SEVERITY = config("NOTICE_SEVERITY", cast=str, default="none")
NOTICE_MESSAGE = config("NOTICE_MESSAGE", cast=str, default="")
class CoordinatorInfoTest(TestCase):
su_pass = "12345678"
su_name = config("ESCROW_USERNAME", cast=str, default="admin")
def setUp(self):
"""
Create a superuser. The superuser is the escrow party.
"""
self.client = Client()
User.objects.create_superuser(self.su_name, "super@user.com", self.su_pass)
@patch("api.lightning.cln.node_pb2_grpc.NodeStub", MockNodeStub)
@patch("api.lightning.lnd.verrpc_pb2_grpc.VersionerStub", MockVersionerStub)
def test_info(self):
path = "/api/info/"
response = self.client.get(path)
data = json.loads(response.content.decode())
self.assertEqual(response.status_code, 200)
self.assertEqual(data["num_public_buy_orders"], 0)
self.assertEqual(data["num_public_sell_orders"], 0)
self.assertEqual(data["book_liquidity"], 0)
self.assertEqual(data["active_robots_today"], 0)
self.assertEqual(data["last_day_nonkyc_btc_premium"], 0)
self.assertEqual(data["last_day_volume"], 0)
self.assertEqual(data["lifetime_volume"], 0)
self.assertEqual(data["lnd_version"], "v0.17.0-beta")
self.assertEqual(data["cln_version"], "v23.08")
self.assertEqual(
data["robosats_running_commit_hash"], "00000000000000000000 dev"
)
self.assertEqual(data["version"], settings.VERSION)
self.assertEqual(data["node_id"], NODE_ID)
self.assertEqual(data["network"], "testnet")
self.assertAlmostEqual(data["maker_fee"], MAKER_FEE)
self.assertAlmostEqual(data["taker_fee"], TAKER_FEE)
self.assertAlmostEqual(data["bond_size"], BOND_SIZE)
self.assertEqual(data["notice_severity"], NOTICE_SEVERITY)
self.assertEqual(data["notice_message"], NOTICE_MESSAGE)
self.assertEqual(data["current_swap_fee_rate"], 0)

View File

@ -9,6 +9,7 @@ from django.test import Client, TestCase
from api.models import Currency, Order from api.models import Currency, Order
from api.tasks import cache_market from api.tasks import cache_market
from tests.mocks.cln import MockHoldStub, MockNodeStub
from tests.mocks.lnd import ( from tests.mocks.lnd import (
MockInvoicesStub, MockInvoicesStub,
MockLightningStub, MockLightningStub,
@ -225,11 +226,13 @@ class TradeTest(TestCase):
) )
self.assertIsNone(data["taker"], "New order's taker is not null") self.assertIsNone(data["taker"], "New order's taker is not null")
@patch("api.lightning.lightning_pb2_grpc.LightningStub", MockLightningStub) @patch("api.lightning.cln.node_pb2_grpc.NodeStub", MockNodeStub)
@patch("api.lightning.invoices_pb2_grpc.InvoicesStub", MockInvoicesStub) @patch("api.lightning.cln.hold_pb2_grpc.HoldStub", MockHoldStub)
@patch("api.lightning.router_pb2_grpc.RouterStub", MockRouterStub) @patch("api.lightning.lnd.verrpc_pb2_grpc.VersionerStub", MockVersionerStub)
@patch("api.lightning.signer_pb2_grpc.SignerStub", MockSignerStub) @patch("api.lightning.lnd.lightning_pb2_grpc.LightningStub", MockLightningStub)
@patch("api.lightning.verrpc_pb2_grpc.VersionerStub", MockVersionerStub) @patch("api.lightning.lnd.invoices_pb2_grpc.InvoicesStub", MockInvoicesStub)
@patch("api.lightning.lnd.router_pb2_grpc.RouterStub", MockRouterStub)
@patch("api.lightning.lnd.signer_pb2_grpc.SignerStub", MockSignerStub)
def test_maker_bond_locked(self): def test_maker_bond_locked(self):
self.test_create_order( self.test_create_order(
robot_index=1, robot_index=1,