Add chat REST tests and improve doc specs

This commit is contained in:
Reckless_Satoshi 2023-11-28 17:21:27 +00:00
parent 92b041cb31
commit ddb91b1cc4
No known key found for this signature in database
GPG Key ID: 9C4585B561315571
4 changed files with 201 additions and 25 deletions

View File

@ -3,24 +3,73 @@ from rest_framework import serializers
from chat.models import Message
class ChatSerializer(serializers.ModelSerializer):
class OutMessagesSerializer(serializers.ModelSerializer):
time = serializers.DateTimeField(source="created_at")
message = serializers.CharField(source="PGP_message")
nick = serializers.CharField(source="sender")
class Meta:
model = Message
fields = (
"index",
"sender",
"PGP_message",
"created_at",
"time",
"message",
"nick",
)
depth = 0
class PostMessageSerializer(serializers.ModelSerializer):
class InMessageSerializer(serializers.ModelSerializer):
class Meta:
model = Message
fields = ("PGP_message", "order_id", "offset")
fields = (
"index",
"created_at",
"PGP_message",
"sender",
)
depth = 0
class ChatSerializer(serializers.ModelSerializer):
offset = serializers.IntegerField(
allow_null=True,
default=None,
required=False,
min_value=0,
help_text="Offset for message index to get as response",
)
peer_pubkey = serializers.CharField(
required=False,
help_text="Your peer's public PGP key",
)
peer_connected = serializers.BooleanField(
required=False,
help_text="Whether your peer has connected recently to the chatroom",
)
messages = serializers.ListField(child=OutMessagesSerializer(), required=False)
class Meta:
model = Message
fields = ("messages", "offset", "peer_connected", "peer_pubkey")
depth = 0
class PostMessageSerializer(serializers.ModelSerializer):
PGP_message = serializers.CharField(
required=True,
help_text="A new chat message",
)
order_id = serializers.IntegerField(
required=True,
min_value=0,
help_text="Your peer's public key",
)
offset = serializers.IntegerField(
allow_null=True,
default=None,
@ -29,9 +78,7 @@ class PostMessageSerializer(serializers.ModelSerializer):
help_text="Offset for message index to get as response",
)
order_id = serializers.IntegerField(
allow_null=False,
required=True,
min_value=0,
help_text="Order ID of chatroom",
)
class Meta:
model = Message
fields = ("PGP_message", "order_id", "offset")
depth = 0

View File

@ -4,6 +4,7 @@ from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.contrib.auth.models import User
from django.utils import timezone
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import status, viewsets
from rest_framework.authentication import TokenAuthentication
from rest_framework.permissions import IsAuthenticated
@ -12,11 +13,11 @@ from rest_framework.response import Response
from api.models import Order
from api.tasks import send_notification
from chat.models import ChatRoom, Message
from chat.serializers import ChatSerializer, PostMessageSerializer
from chat.serializers import ChatSerializer, InMessageSerializer, PostMessageSerializer
class ChatView(viewsets.ViewSet):
serializer_class = PostMessageSerializer
serializer_class = ChatSerializer
authentication_classes = [TokenAuthentication]
permission_classes = [IsAuthenticated]
@ -26,6 +27,15 @@ class ChatView(viewsets.ViewSet):
order__status__in=[Order.Status.CHA, Order.Status.FSE]
)
@extend_schema(
request=ChatSerializer,
parameters=[
OpenApiParameter(
name="order_id", location=OpenApiParameter.QUERY, type=int
),
OpenApiParameter(name="offset", location=OpenApiParameter.QUERY, type=int),
],
)
def get(self, request, format=None):
"""
Returns chat messages for an order with an index higher than `offset`.
@ -87,7 +97,7 @@ class ChatView(viewsets.ViewSet):
messages = []
for message in queryset:
d = ChatSerializer(message).data
d = InMessageSerializer(message).data
# Re-serialize so the response is identical to the consumer message
data = {
"index": d["index"],
@ -105,12 +115,13 @@ class ChatView(viewsets.ViewSet):
return Response(response, status.HTTP_200_OK)
@extend_schema(request=PostMessageSerializer, responses=ChatSerializer)
def post(self, request, format=None):
"""
Adds one new message to the chatroom.
Adds one new message to the chatroom. If `offset` is given, will return every new message as well.
"""
serializer = self.serializer_class(data=request.data)
serializer = PostMessageSerializer(data=request.data)
# Return bad request if serializer is not valid
if not serializer.is_valid():
context = {"bad_request": "Invalid serializer"}
@ -175,8 +186,10 @@ class ChatView(viewsets.ViewSet):
# Send websocket message
if chatroom.maker == request.user:
peer_connected = chatroom.taker_connected
peer_public_key = order.taker.robot.public_key
elif chatroom.taker == request.user:
peer_connected = chatroom.maker_connected
peer_public_key = order.maker.robot.public_key
channel_layer = get_channel_layer()
async_to_sync(channel_layer.group_send)(
@ -197,7 +210,7 @@ class ChatView(viewsets.ViewSet):
queryset = Message.objects.filter(order=order, index__gt=offset)
messages = []
for message in queryset:
d = ChatSerializer(message).data
d = InMessageSerializer(message).data
# Re-serialize so the response is identical to the consumer message
data = {
"index": d["index"],
@ -207,7 +220,11 @@ class ChatView(viewsets.ViewSet):
}
messages.append(data)
response = {"peer_connected": peer_connected, "messages": messages}
response = {
"peer_connected": peer_connected,
"messages": messages,
"peer_pubkey": peer_public_key,
}
else:
response = {}

View File

@ -62,6 +62,15 @@ paths:
get:
operationId: chat_retrieve
description: Returns chat messages for an order with an index higher than `offset`.
parameters:
- in: query
name: offset
schema:
type: integer
- in: query
name: order_id
schema:
type: integer
tags:
- chat
security:
@ -71,11 +80,12 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/PostMessage'
$ref: '#/components/schemas/Chat'
description: ''
post:
operationId: chat_create
description: Adds one new message to the chatroom.
description: Adds one new message to the chatroom. If `offset` is given, will
return every new message as well.
tags:
- chat
requestBody:
@ -97,7 +107,7 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/PostMessage'
$ref: '#/components/schemas/Chat'
description: ''
/api/historical/:
get:
@ -910,6 +920,24 @@ components:
BlankEnum:
enum:
- ''
Chat:
type: object
properties:
messages:
type: array
items:
$ref: '#/components/schemas/OutMessages'
offset:
type: integer
minimum: 0
nullable: true
description: Offset for message index to get as response
peer_connected:
type: boolean
description: Whether your peer has connected recently to the chatroom
peer_pubkey:
type: string
description: Your peer's public PGP key
ClaimReward:
type: object
properties:
@ -1589,6 +1617,24 @@ components:
- expires_at
- id
- type
OutMessages:
type: object
properties:
index:
type: integer
maximum: 2147483647
minimum: 0
time:
type: string
format: date-time
message:
type: string
nick:
type: string
required:
- message
- nick
- time
PlatformSummary:
type: object
properties:
@ -1623,18 +1669,18 @@ components:
properties:
PGP_message:
type: string
nullable: true
maxLength: 5000
description: A new chat message
order_id:
type: integer
minimum: 0
description: Order ID of chatroom
description: Your peer's public key
offset:
type: integer
minimum: 0
nullable: true
description: Offset for message index to get as response
required:
- PGP_message
- order_id
RatingEnum:
enum:

View File

@ -758,6 +758,72 @@ class TradeTest(BaseAPITestCase):
)
self.assertEqual(data["expiry_reason"], Order.ExpiryReasons.NINVOI)
def test_chat(self):
"""
Tests the chatting REST functionality
"""
path = reverse("chat")
message = (
"Example message string. Note clients will verify expect only PGP messages."
)
# Run a successful trade
trade = Trade(self.client)
trade.publish_order()
trade.take_order()
trade.lock_taker_bond()
trade.lock_escrow(trade.taker_index)
trade.submit_payout_invoice(trade.maker_index)
params = f"?order_id={trade.order_id}"
maker_headers = trade.get_robot_auth(trade.maker_index)
taker_headers = trade.get_robot_auth(trade.taker_index)
maker_nick = read_file(f"tests/robots/{trade.maker_index}/nickname")
taker_nick = read_file(f"tests/robots/{trade.taker_index}/nickname")
# Get empty chatroom as maker
response = self.client.get(path + params, **maker_headers)
self.assertResponse(response)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["messages"], [])
self.assertTrue(response.json()["peer_connected"])
# Get empty chatroom as taker
response = self.client.get(path + params, **taker_headers)
self.assertResponse(response)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["messages"], [])
self.assertTrue(response.json()["peer_connected"])
# Post new message as maker
body = {"PGP_message": message, "order_id": trade.order_id, "offset": 0}
response = self.client.post(path + params, data=body, **maker_headers)
self.assertResponse(response)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["messages"][0]["message"], message)
self.assertTrue(response.json()["peer_connected"])
# Post new message as taker without offset, so response should not have messages.
body = {"PGP_message": message + " 2", "order_id": trade.order_id}
response = self.client.post(path + params, data=body, **taker_headers)
self.assertResponse(response)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {}) # Nothing in the response
# Get the two chatroom messages as maker
response = self.client.get(path + params, **maker_headers)
self.assertResponse(response)
self.assertEqual(response.status_code, 200)
self.assertTrue(response.json()["peer_connected"])
self.assertEqual(response.json()["messages"][0]["message"], message)
self.assertEqual(response.json()["messages"][1]["message"], message + " 2")
self.assertEqual(response.json()["messages"][0]["nick"], maker_nick)
self.assertEqual(response.json()["messages"][1]["nick"], taker_nick)
# Cancel order to avoid leaving pending HTLCs after a successful test
trade.cancel_order(trade.maker_index)
trade.cancel_order(trade.taker_index)
def test_ticks(self):
"""
Tests the historical ticks serving endpoint after creating a contract