modules/websocket: Updated connection reuse and closing flags for WebSocket handshake.
[sip-router] / modules / websocket / ws_handshake.c
1 /*
2  * $Id$
3  *
4  * Copyright (C) 2012 Crocodile RCS Ltd
5  *
6  * This file is part of Kamailio, a free SIP server.
7  *
8  * Kamailio is free software; you can redistribute it and/or modify
9  * it under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 2 of the License, or
11  * (at your option) any later version
12  *
13  * Kamailio is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License 
19  * along with this program; if not, write to the Free Software 
20  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  *
22  */
23
24 #include <openssl/sha.h>
25
26 #include "../../basex.h"
27 #include "../../data_lump_rpl.h"
28 #include "../../dprint.h"
29 #include "../../locking.h"
30 #include "../../str.h"
31 #include "../../tcp_conn.h"
32 #include "../../lib/kcore/kstats_wrapper.h"
33 #include "../../lib/kcore/cmpapi.h"
34 #include "../../lib/kmi/tree.h"
35 #include "../../mem/mem.h"
36 #include "../../parser/msg_parser.h"
37 #include "../sl/sl.h"
38 #include "../tls/tls_cfg.h"
39 #include "ws_conn.h"
40 #include "ws_handshake.h"
41 #include "ws_mod.h"
42
43 #define WS_VERSION              (13)
44
45 stat_var *ws_failed_handshakes;
46 stat_var *ws_successful_handshakes;
47
48 static str str_sip = str_init("sip");
49 static str str_upgrade = str_init("upgrade");
50 static str str_websocket = str_init("websocket");
51 static str str_ws_guid = str_init("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
52
53 /* HTTP headers */
54 static str str_hdr_connection = str_init("Connection");
55 static str str_hdr_upgrade = str_init("Upgrade");
56 static str str_hdr_sec_websocket_accept = str_init("Sec-WebSocket-Accept");
57 static str str_hdr_sec_websocket_key = str_init("Sec-WebSocket-Key");
58 static str str_hdr_sec_websocket_protocol = str_init("Sec-WebSocket-Protocol");
59 static str str_hdr_sec_websocket_version = str_init("Sec-WebSocket-Version");
60 #define CONNECTION              (1<<0)
61 #define UPGRADE                 (1<<1)
62 #define SEC_WEBSOCKET_ACCEPT    (1<<2)
63 #define SEC_WEBSOCKET_KEY       (1<<3)
64 #define SEC_WEBSOCKET_PROTOCOL  (1<<4)
65 #define SEC_WEBSOCKET_VERSION   (1<<5)
66
67 #define REQUIRED_HEADERS        (CONNECTION | UPGRADE | SEC_WEBSOCKET_KEY\
68                                         | SEC_WEBSOCKET_PROTOCOL\
69                                         | SEC_WEBSOCKET_VERSION)
70
71 /* HTTP status text */
72 static str str_status_switching_protocols = str_init("Switching Protocols");
73 static str str_status_bad_request = str_init("Bad Request");
74 static str str_status_upgrade_required = str_init("Upgrade Required");
75 static str str_status_internal_server_error = str_init("Internal Server Error");
76 static str str_status_service_unavailable = str_init("Service Unavailable");
77
78 #define HDR_BUF_LEN             (256)
79 static char headers_buf[HDR_BUF_LEN];
80
81 #define KEY_BUF_LEN             (28)
82 static char key_buf[KEY_BUF_LEN];
83
84 static int ws_send_reply(sip_msg_t *msg, int code, str *reason, str *hdrs)
85 {
86         if (hdrs && hdrs->len > 0)
87         {
88                 if (add_lump_rpl(msg, hdrs->s, hdrs->len, LUMP_RPL_HDR) == 0)
89                 {
90                         LM_ERR("inserting extra-headers lump\n");
91                         update_stat(ws_failed_handshakes, 1);
92                         return -1;
93                 }
94         }
95
96         if (ws_slb.freply(msg, code, reason) < 0)
97         {
98                 LM_ERR("sending reply\n");
99                 update_stat(ws_failed_handshakes, 1);
100                 return -1;
101         }
102
103         update_stat(
104                 code == 101 ? ws_successful_handshakes : ws_failed_handshakes,
105                 1);
106
107         return 0;
108 }
109
110 int ws_handle_handshake(struct sip_msg *msg)
111 {
112         str key = {0, 0}, headers = {0, 0}, reply_key = {0, 0};
113         unsigned char sha1[20];
114         unsigned int hdr_flags = 0;
115         int version;
116         struct hdr_field *hdr = msg->headers;
117         struct tcp_connection *con;
118         ws_connection_t *wsc;
119
120         /* Make sure that the connection is closed after the response _and_
121            the existing connection (from the request) is reused for the
122            response.  The close flag will be unset later if the handshake is
123            successful. */
124         msg->rpl_send_flags.f |= SND_F_CON_CLOSE;
125         msg->rpl_send_flags.f |= SND_F_FORCE_CON_REUSE;
126
127         if (*ws_enabled == 0)
128         {
129                 LM_INFO("disabled: bouncing handshake\n");
130                 ws_send_reply(msg, 503, &str_status_service_unavailable,
131                                 NULL);
132                 return 0;
133         }
134
135         /* Retrieve TCP/TLS connection */
136         if ((con = tcpconn_get(msg->rcv.proto_reserved1, 0, 0, 0, 0)) == NULL)
137         {
138                 LM_ERR("retrieving connection\n");
139                 ws_send_reply(msg, 500, &str_status_internal_server_error,
140                                 NULL);
141                 return 0;
142         }
143
144         if (con->type != PROTO_TCP && con->type != PROTO_TLS)
145         {
146                 LM_ERR("unsupported transport: %d", con->type);
147                 return 0;
148         }
149
150         /* Process HTTP headers */
151         while (hdr != NULL)
152         {
153                 /* Decode and validate Connection */
154                 if (cmp_hdrname_strzn(&hdr->name,
155                                 str_hdr_connection.s,
156                                 str_hdr_connection.len) == 0)
157                 {
158                         strlower(&hdr->body);
159                         if (str_search(&hdr->body, &str_upgrade) != NULL)
160                         {
161                                 LM_DBG("found %.*s: %.*s\n",
162
163                                         hdr->name.len, hdr->name.s,
164                                         hdr->body.len, hdr->body.s);
165                                 hdr_flags |= CONNECTION;
166                         }
167                 }
168                 /* Decode and validate Upgrade */
169                 else if (cmp_hdrname_strzn(&hdr->name,
170                                 str_hdr_upgrade.s,
171                                 str_hdr_upgrade.len) == 0)
172                 {
173                         strlower(&hdr->body);
174                         if (str_search(&hdr->body, &str_websocket) != NULL)
175                         {
176                                 LM_DBG("found %.*s: %.*s\n",
177                                         hdr->name.len, hdr->name.s,
178                                         hdr->body.len, hdr->body.s);
179                                 hdr_flags |= UPGRADE;
180                         }
181                 }
182                 /* Decode and validate Sec-WebSocket-Key */
183                 else if (cmp_hdrname_strzn(&hdr->name,
184                                 str_hdr_sec_websocket_key.s, 
185                                 str_hdr_sec_websocket_key.len) == 0) 
186                 {
187                         if (hdr_flags & SEC_WEBSOCKET_KEY)
188                         {
189                                 LM_WARN("%.*s found multiple times\n",
190                                         hdr->name.len, hdr->name.s);
191                                 ws_send_reply(msg, 400,
192                                                 &str_status_bad_request, NULL);
193                                 return 0;
194                         }
195
196                         LM_DBG("found %.*s: %.*s\n",
197                                 hdr->name.len, hdr->name.s,
198                                 hdr->body.len, hdr->body.s);
199                         key = hdr->body;
200                         hdr_flags |= SEC_WEBSOCKET_KEY;
201                 }
202                 /* Decode and validate Sec-WebSocket-Protocol */
203                 else if (cmp_hdrname_strzn(&hdr->name,
204                                 str_hdr_sec_websocket_protocol.s,
205                                 str_hdr_sec_websocket_protocol.len) == 0)
206                 {
207                         strlower(&hdr->body);
208                         if (str_search(&hdr->body, &str_sip) != NULL)
209                         {
210                                 LM_DBG("found %.*s: %.*s\n",
211                                         hdr->name.len, hdr->name.s,
212                                         hdr->body.len, hdr->body.s);
213                                 hdr_flags |= SEC_WEBSOCKET_PROTOCOL;
214                         }
215                 }
216                 /* Decode and validate Sec-WebSocket-Version */
217                 else if (cmp_hdrname_strzn(&hdr->name,
218                                 str_hdr_sec_websocket_version.s,
219                                 str_hdr_sec_websocket_version.len) == 0)
220                 {
221                         if (hdr_flags & SEC_WEBSOCKET_VERSION)
222                         {
223                                 LM_WARN("%.*s found multiple times\n",
224                                         hdr->name.len, hdr->name.s);
225                                 ws_send_reply(msg, 400,
226                                                 &str_status_bad_request, NULL);
227                                 return 0;
228                         }
229
230                         str2sint(&hdr->body, &version);
231
232                         if (version != WS_VERSION)
233                         {
234                                 LM_WARN("Unsupported protocol version %.*s\n",
235                                         hdr->body.len, hdr->body.s);
236                                 headers.s = headers_buf;
237                                 headers.len = snprintf(headers.s, HDR_BUF_LEN,
238                                         "%.*s: %d\r\n",
239                                         str_hdr_sec_websocket_version.len,
240                                         str_hdr_sec_websocket_version.s,
241                                         WS_VERSION);
242                                 ws_send_reply(msg, 426,
243                                                 &str_status_upgrade_required,
244                                                 &headers);
245                                 return 0;
246                         }
247
248                         LM_DBG("found %.*s: %.*s\n",
249                                 hdr->name.len, hdr->name.s,
250                                 hdr->body.len, hdr->body.s);
251                         hdr_flags |= SEC_WEBSOCKET_VERSION;
252                 }
253
254                 hdr = hdr->next;
255         }
256
257         /* Final check that all required headers/values were found */
258         if (hdr_flags != REQUIRED_HEADERS)
259         {
260                 LM_WARN("required headers not present\n");
261                 headers.s = headers_buf;
262                 headers.len = snprintf(headers.s, HDR_BUF_LEN,
263                                         "%.*s: %.*s\r\n"
264                                         "%.*s: %d\r\n",
265                                         str_hdr_sec_websocket_protocol.len,
266                                         str_hdr_sec_websocket_protocol.s,
267                                         str_sip.len, str_sip.s,
268                                         str_hdr_sec_websocket_version.len,
269                                         str_hdr_sec_websocket_version.s,
270                                         WS_VERSION);
271                 ws_send_reply(msg, 400, &str_status_bad_request, NULL);
272                 return 0;
273         }
274
275         /* Construct reply_key */
276         reply_key.s = (char *) pkg_malloc(
277                                 (key.len + str_ws_guid.len) * sizeof(char)); 
278         if (reply_key.s == NULL)
279         {
280                 LM_ERR("allocating pkg memory\n");
281                 ws_send_reply(msg, 500, &str_status_internal_server_error,
282                                 NULL);
283                 return 0;
284         }
285         memcpy(reply_key.s, key.s, key.len);
286         memcpy(reply_key.s + key.len, str_ws_guid.s, str_ws_guid.len);
287         reply_key.len = key.len + str_ws_guid.len;
288         SHA1((const unsigned char *) reply_key.s, reply_key.len, sha1);
289         pkg_free(reply_key.s);
290         reply_key.s = key_buf;
291         reply_key.len = base64_enc(sha1, 20,
292                                 (unsigned char *) reply_key.s, KEY_BUF_LEN);
293
294         /* Add the connection to the WebSocket connection table */
295         wsconn_add(msg->rcv.proto_reserved1);
296
297         /* Make sure Kamailio core sends future messages on this connection
298            directly to this module */
299         if (con->type == PROTO_TLS)
300                 con->type = con->rcv.proto = PROTO_WSS;
301         else
302                 con->type = con->rcv.proto = PROTO_WS;
303
304         /* Now Kamailio is ready to receive WebSocket frames build and send a
305            101 reply */
306         headers.s = headers_buf;
307         headers.len = snprintf(headers.s, HDR_BUF_LEN,
308                         "%.*s: %.*s\r\n"
309                         "%.*s: %.*s\r\n"
310                         "%.*s: %.*s\r\n"
311                         "%.*s: %.*s\r\n",
312                         str_hdr_upgrade.len, str_hdr_upgrade.s,
313                         str_websocket.len, str_websocket.s,
314                         str_hdr_connection.len, str_hdr_connection.s,
315                         str_upgrade.len, str_upgrade.s,
316                         str_hdr_sec_websocket_accept.len,
317                         str_hdr_sec_websocket_accept.s, reply_key.len,
318                         reply_key.s, str_hdr_sec_websocket_protocol.len,
319                         str_hdr_sec_websocket_protocol.s, str_sip.len,
320                         str_sip.s);
321         msg->rpl_send_flags.f &= ~SND_F_CON_CLOSE;
322         if (ws_send_reply(msg, 101,
323                                 &str_status_switching_protocols, &headers) < 0)
324                 if ((wsc = wsconn_get(msg->rcv.proto_reserved1)) != NULL)
325                         wsconn_rm(wsc);
326
327         return 0;
328 }
329
330 struct mi_root *ws_mi_disable(struct mi_root *cmd, void *param)
331 {
332         *ws_enabled = 0;
333         LM_WARN("disabling websockets - new connections will be dropped\n");
334         return init_mi_tree(200, MI_OK_S, MI_OK_LEN);
335 }
336
337 struct mi_root *ws_mi_enable(struct mi_root *cmd, void *param)
338 {
339         *ws_enabled = 1;
340         LM_WARN("enabling websockets\n");
341         return init_mi_tree(200, MI_OK_S, MI_OK_LEN);
342 }