1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  package net.brabenetz.app.springstompserver.testtools;
21  
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.simp.stomp.StompFrameHandler;
import org.springframework.messaging.simp.stomp.StompHeaders;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.messaging.WebSocketStompClient;
import org.springframework.web.socket.sockjs.client.SockJsClient;
import org.springframework.web.socket.sockjs.client.Transport;
import org.springframework.web.socket.sockjs.client.WebSocketTransport;

import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.BiConsumer;
38  
39  public class WebSocketUtils {
40  
41      static final Logger LOG = LoggerFactory.getLogger(WebSocketUtils.class);
42  
43      public static WebSocketStompClient createStompClient() {
44          WebSocketClient webSocketClient = new StandardWebSocketClient();
45          WebSocketStompClient stompClient = new WebSocketStompClient(webSocketClient);
46          stompClient.setMessageConverter(new MappingJackson2MessageConverter());
47          return stompClient;
48      }
49  
50      public static WebSocketStompClient createStompOverSocketJsClient() {
51          List<Transport> transports = new ArrayList<>();
52          transports.add(new WebSocketTransport(new StandardWebSocketClient()));
53          WebSocketClient transport = new SockJsClient(transports);
54          WebSocketStompClient stompClient = new WebSocketStompClient(transport);
55          stompClient.setMessageConverter(new MappingJackson2MessageConverter());
56          return stompClient;
57      }
58  
59      public static StompFrameHandler createStompFrameHandler(BiConsumer<StompHeaders, String> handleFrameConsumer) {
60          return new StompFrameHandler() {
61  
62              @Override
63              public Type getPayloadType(StompHeaders headers) {
64                  return String.class;
65              }
66  
67              @Override
68              public void handleFrame(StompHeaders headers, Object payload) {
69                  handleFrameConsumer.accept(headers, (String) payload);
70              }
71  
72          };
73      }
74  }