View Javadoc
1   /*-
2    * #%L
3    * Spring Stomp Server
4    * ===============================================================
5    * Copyright (C) 2020 Brabenetz Harald, Austria
6    * ===============================================================
7    * Licensed under the Apache License, Version 2.0 (the "License");
8    * you may not use this file except in compliance with the License.
9    * You may obtain a copy of the License at
10   *
11   *      http://www.apache.org/licenses/LICENSE-2.0
12   *
13   * Unless required by applicable law or agreed to in writing, software
14   * distributed under the License is distributed on an "AS IS" BASIS,
15   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16   * See the License for the specific language governing permissions and
17   * limitations under the License.
18   * #L%
19   */
20  package net.brabenetz.app.springstompserver;
21  
22  import com.github.tomakehurst.wiremock.WireMockServer;
23  import com.github.tomakehurst.wiremock.client.WireMock;
24  import com.maciejwalkowiak.wiremock.spring.ConfigureWireMock;
25  import com.maciejwalkowiak.wiremock.spring.EnableWireMock;
26  import com.maciejwalkowiak.wiremock.spring.InjectWireMock;
27  import net.brabenetz.app.springstompserver.config.WebSocketInitLoadConfigProperties;
28  import net.brabenetz.app.springstompserver.testtools.WebSocketPayload;
29  import net.brabenetz.app.springstompserver.testtools.WebSocketStompSessionHandler;
30  import net.brabenetz.app.springstompserver.testtools.WebSocketUtils;
31  import org.junit.jupiter.api.AfterEach;
32  import org.junit.jupiter.api.BeforeEach;
33  import org.junit.jupiter.api.Test;
34  import org.slf4j.Logger;
35  import org.slf4j.LoggerFactory;
36  import org.springframework.beans.factory.annotation.Autowired;
37  import org.springframework.boot.test.context.SpringBootTest;
38  import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
39  import org.springframework.boot.web.server.LocalServerPort;
40  import org.springframework.core.env.Environment;
41  import org.springframework.messaging.simp.stomp.StompHeaders;
42  import org.springframework.messaging.simp.stomp.StompSession;
43  import org.springframework.messaging.simp.stomp.StompSession.Subscription;
44  import org.springframework.util.concurrent.ListenableFuture;
45  import org.springframework.web.socket.WebSocketHttpHeaders;
46  import org.springframework.web.socket.messaging.WebSocketStompClient;
47  
48  import java.util.ArrayList;
49  import java.util.List;
50  import java.util.Map;
51  import java.util.concurrent.CountDownLatch;
52  import java.util.concurrent.TimeUnit;
53  
54  import static org.assertj.core.api.Assertions.assertThat;
55  
56  @SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT)
57  @EnableWireMock(@ConfigureWireMock(name = "wiremock", property = "wiremock.baseurl"))
58  class SpringStompServerApplicationTests {
59  
60      private static final Logger LOG = LoggerFactory.getLogger(SpringStompServerApplicationTests.class);
61  
62      @LocalServerPort
63      private int port;
64  
65      @InjectWireMock("wiremock")
66      private WireMockServer wiremock;
67  
68      @Autowired
69      private Environment env;
70  
71      @Autowired
72      private WebSocketInitLoadConfigProperties webSocketInitLoadConfigProperties;
73  
74      private StompSession currentSession;
75      private List<Subscription> subscriptions;
76  
77      @BeforeEach
78      public void init() {
79          currentSession = null;
80          subscriptions = new ArrayList<>();
81  
82          String newProxyUrl = webSocketInitLoadConfigProperties.getProxyUrl()
83                  .replaceAll("http://localhost:8181", env.getProperty("wiremock.baseurl"));
84          webSocketInitLoadConfigProperties.setProxyUrl(newProxyUrl);
85  
86      }
87  
88      @AfterEach
89      public void cleanupSession() {
90          subscriptions.forEach(consumer -> consumer.unsubscribe());
91          if (currentSession != null) {
92              currentSession.disconnect();
93          }
94      }
95  
96      @Test
97      public void testStompWebsocketEndpoint() throws Exception {
98          String clientName = "STOMP-WebSocket";
99          String websocketEndpoint = "ws://localhost:" + port + "/websocket";
100         WebSocketStompClient stompClient = WebSocketUtils.createStompClient();
101 
102         testStompWebsocketEndpoint(clientName, websocketEndpoint, stompClient);
103     }
104 
105     @Test
106     public void testStompWebsocketEndpointWithSockJs() throws Exception {
107         String clientName = "SOCKJS-WebSocket";
108         String websocketEndpoint = "http://localhost:" + port + "/websocket";
109         WebSocketStompClient stompClient = WebSocketUtils.createStompOverSocketJsClient();
110 
111         testStompWebsocketEndpoint(clientName, websocketEndpoint, stompClient);
112     }
113 
114     private void testStompWebsocketEndpoint(String clientName, String websocketEndpoint, WebSocketStompClient client)
115             throws Exception {
116         // connect to the Websocket
117         WebSocketStompSessionHandlermpSessionHandler.html#WebSocketStompSessionHandler">WebSocketStompSessionHandler sessionHandler = new WebSocketStompSessionHandler(clientName);
118         ListenableFuture<StompSession> stompClientConnection = client.connect(websocketEndpoint, sessionHandler);
119         // wait until Session is created
120         currentSession = stompClientConnection.get(1, TimeUnit.SECONDS);
121 
122         // container to collect and sync the result
123         CountDownLatch doneSignal = new CountDownLatch(1);
124         List<WebSocketPayload<String>> messages = new ArrayList<>();
125 
126         // Register Subscription
127         subscriptions.add(currentSession.subscribe("/topic/test/1234", WebSocketUtils.createStompFrameHandler((StompHeaders headers, String payload) -> {
128             LOG.info(clientName + " FrameHandler got new Payload: " + payload);
129             messages.add(new WebSocketPayload<>(headers, payload));
130             doneSignal.countDown();
131             headers.forEach((k, v) -> {
132                 LOG.info(String.format(clientName + " HEADER: %s = %s", k, v));
133                 // Example Output:
134                 // STOMP-WebSocket HEADER: destination = [/topic/test/1234]
135                 // STOMP-WebSocket HEADER: content-type = [application/json]
136                 // STOMP-WebSocket HEADER: subscription = [0]
137                 // STOMP-WebSocket HEADER: message-id = [6ff3affb-eacd-ea53-b403-60dfdf8fcd88-0]
138                 // STOMP-WebSocket HEADER: content-length = [6]
139             });
140         })));
141 
142         // Send Message
143         currentSession.send("/topic/test/1234", "test");
144 
145         // Wait until Message is consumed
146         doneSignal.await(1, TimeUnit.SECONDS);
147         assertThat(doneSignal.getCount()).describedAs("CountDownLatch").isEqualTo(0);
148 
149         // Verify that consumed Message Payload.
150         assertThat(messages).hasSize(1);
151         assertThat(messages.get(0).getBody()).isEqualTo("test");
152         assertThat(messages.get(0).getHeaders().get("destination")).containsExactly("/topic/test/1234");
153         assertThat(messages.get(0).getHeaders().get("content-type")).containsExactly("application/json");
154     }
155 
156     @Test
157     public void testStompWebsocketEndpointWithInitialLoad() throws Exception {
158         // prepare Wiremock:
159         wiremock.stubFor(WireMock.get("/mocked-init-load/test/1234").willReturn(
160                 WireMock.ok("\"test\"")
161                         .withHeader("content-type", "application/json")));
162 
163         String clientName = "STOMP-WebSocket";
164         String websocketEndpoint = "ws://localhost:" + port + "/websocket";
165         WebSocketStompClient stompClient = WebSocketUtils.createStompClient();
166 
167         // connect to the Websocket
168         WebSocketStompSessionHandlermpSessionHandler.html#WebSocketStompSessionHandler">WebSocketStompSessionHandler sessionHandler = new WebSocketStompSessionHandler(clientName);
169         WebSocketHttpHeaders connectionHeaders = new WebSocketHttpHeaders();
170         connectionHeaders.add("x-test", "Dummy Connection Header");
171         ListenableFuture<StompSession> stompClientConnection = stompClient.connect(websocketEndpoint, connectionHeaders, sessionHandler);
172         // wait until Session is created
173         currentSession = stompClientConnection.get(1, TimeUnit.SECONDS);
174 
175         // container to collect and sync the result
176         CountDownLatch doneSignal = new CountDownLatch(1);
177         List<WebSocketPayload<String>> messages = new ArrayList<>();
178 
179         StompHeaders stompHeaders = new StompHeaders();
180         stompHeaders.add("x-stomp-header", "Dummy Subscr. Header");
181         stompHeaders.setDestination("/user/123456/topic/test/1234");
182 
183         subscriptions
184                 .add(currentSession.subscribe(stompHeaders, WebSocketUtils.createStompFrameHandler((StompHeaders headers, String payload) -> {
185             LOG.info(clientName + " FrameHandler got new Payload: " + payload);
186             messages.add(new WebSocketPayload<>(headers, payload));
187             doneSignal.countDown();
188             headers.forEach((k, v) -> {
189                 LOG.info(String.format(clientName + " HEADER: %s = %s", k, v));
190                 // Example Output:
191                 // STOMP-WebSocket HEADER: destination = [/topic/test/1234]
192                 // STOMP-WebSocket HEADER: content-type = [application/json]
193                 // STOMP-WebSocket HEADER: subscription = [0]
194                 // STOMP-WebSocket HEADER: message-id = [6ff3affb-eacd-ea53-b403-60dfdf8fcd88-0]
195                 // STOMP-WebSocket HEADER: content-length = [6]
196             });
197         })));
198 
199         // Wait until Message is consumed
200         doneSignal.await(1, TimeUnit.SECONDS);
201         assertThat(doneSignal.getCount()).describedAs("CountDownLatch").isEqualTo(0);
202 
203         // Verify that consumed Message Payload.
204         assertThat(messages).hasSize(1);
205         assertThat(messages.get(0).getBody()).isEqualTo("test");
206         Map<String, List<String>> headers = messages.get(0).getHeaders();
207         assertThat(headers.get("destination")).containsExactly("/user/123456/topic/test/1234");
208         assertThat(headers.get("content-type")).containsExactly("application/json");
209     }
210 }