modules/websocket: more work on module boiler-plate and handshake
[sip-router] / modules / websocket / ws_handshake.c
1 /*
2  * $Id$
3  *
4  * Copyright (C) 2012 Crocodile RCS Ltd
5  *
6  * This file is part of Kamailio, a free SIP server.
7  *
8  * Kamailio is free software; you can redistribute it and/or modify
9  * it under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 2 of the License, or
11  * (at your option) any later version
12  *
13  * Kamailio is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License 
19  * along with this program; if not, write to the Free Software 
20  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  *
22  */
23
24 #include <openssl/sha.h>
25
26 #include "../../basex.h"
27 #include "../../data_lump_rpl.h"
28 #include "../../dprint.h"
29 #include "../../locking.h"
30 #include "../../lib/kcore/kstats_wrapper.h"
31 #include "../../lib/kcore/cmpapi.h"
32 #include "../../lib/kmi/tree.h"
33 #include "../../parser/msg_parser.h"
34 #include "../sl/sl.h"
35 #include "ws_handshake.h"
36 #include "ws_mod.h"
37
38 #define WS_VERSION              (13)
39
40 static str str_sip = str_init("sip");
41 static str str_upgrade = str_init("upgrade");
42 static str str_websocket = str_init("websocket");
43 static str str_ws_guid = str_init("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
44
45 /* HTTP headers */
46 static str str_hdr_connection = str_init("Connection");
47 static str str_hdr_upgrade = str_init("Upgrade");
48 static str str_hdr_sec_websocket_accept = str_init("Sec-WebSocket-Accept");
49 static str str_hdr_sec_websocket_key = str_init("Sec-WebSocket-Key");
50 static str str_hdr_sec_websocket_protocol = str_init("Sec-WebSocket-Protocol");
51 static str str_hdr_sec_websocket_version = str_init("Sec-WebSocket-Version");
52 #define CONNECTION              (1<<0)
53 #define UPGRADE                 (1<<1)
54 #define SEC_WEBSOCKET_ACCEPT    (1<<2)
55 #define SEC_WEBSOCKET_KEY       (1<<3)
56 #define SEC_WEBSOCKET_PROTOCOL  (1<<4)
57 #define SEC_WEBSOCKET_VERSION   (1<<5)
58
59 #define REQUIRED_HEADERS        (CONNECTION | UPGRADE | SEC_WEBSOCKET_KEY\
60                                         | SEC_WEBSOCKET_PROTOCOL\
61                                         | SEC_WEBSOCKET_VERSION)
62
63 /* HTTP status text */
64 static str str_status_switching_protocols = str_init("Switching Protocols");
65 static str str_status_bad_request = str_init("Bad Request");
66 static str str_status_upgrade_required = str_init("Upgrade Required");
67 static str str_status_internal_server_error = str_init("Internal Server Error");
68 static str str_status_service_unavailable = str_init("Service Unavailable");
69
70 #define HDR_BUF_LEN             (256)
71 static char headers_buf[HDR_BUF_LEN];
72
73 #define KEY_BUF_LEN             (28)
74 static char key_buf[KEY_BUF_LEN];
75
76 static int ws_send_reply(sip_msg_t *msg, int code, str *reason, str *hdrs)
77 {
78         int cur_cons, max_cons;
79
80         if (hdrs && hdrs->len > 0)
81         {
82                 if (add_lump_rpl(msg, hdrs->s, hdrs->len, LUMP_RPL_HDR) == 0)
83                 {
84                         LM_ERR("inserting extra-headers lump\n");
85                         update_stat(ws_failed_handshakes, 1);
86                         return -1;
87                 }
88         }
89
90         if (ws_slb.freply(msg, code, reason) < 0)
91         {
92                 LM_ERR("sending reply\n");
93                 update_stat(ws_failed_handshakes, 1);
94                 return -1;
95         }
96
97         if (code == 101)
98         {
99                 update_stat(ws_successful_handshakes, 1);
100
101                 lock_get(ws_stats_lock);
102                 update_stat(ws_current_connections, 1);
103
104                 cur_cons = get_stat_val(ws_current_connections);
105                 max_cons = get_stat_val(ws_max_concurrent_connections);
106
107                 if (max_cons < cur_cons)
108                         update_stat(ws_max_concurrent_connections,
109                                                 cur_cons - max_cons);
110                 lock_release(ws_stats_lock);
111         }
112         else
113                 update_stat(ws_failed_handshakes, 1);
114
115         return 0;
116 }
117
118 int ws_handle_handshake(struct sip_msg *msg)
119 {
120         str key = {0, 0}, headers = {0, 0}, reply_key = {0, 0};
121         unsigned char sha1[20];
122         unsigned int hdr_flags = 0;
123         int version;
124         struct hdr_field *hdr = msg->headers;
125
126         if (*ws_enabled == 0)
127         {
128                 LM_INFO("disabled: bouncing handshake\n");
129                 ws_send_reply(msg, 503, &str_status_service_unavailable, NULL);
130                 return 0;
131         }
132
133         while (hdr != NULL)
134         {
135                 /* Decode and validate Connection */
136                 if (cmp_hdrname_strzn(&hdr->name,
137                                 str_hdr_connection.s,
138                                 str_hdr_connection.len) == 0)
139                 {
140                         strlower(&hdr->body);
141                         if (str_search(&hdr->body, &str_upgrade) != NULL)
142                         {
143                                 LM_INFO("found %.*s: %.*s\n",
144                                         hdr->name.len, hdr->name.s,
145                                         hdr->body.len, hdr->body.s);
146                                 hdr_flags |= CONNECTION;
147                         }
148                 }
149                 /* Decode and validate Upgrade */
150                 else if (cmp_hdrname_strzn(&hdr->name,
151                                 str_hdr_upgrade.s,
152                                 str_hdr_upgrade.len) == 0)
153                 {
154                         strlower(&hdr->body);
155                         if (str_search(&hdr->body, &str_websocket) != NULL)
156                         {
157                                 LM_INFO("found %.*s: %.*s\n",
158                                         hdr->name.len, hdr->name.s,
159                                         hdr->body.len, hdr->body.s);
160                                 hdr_flags |= UPGRADE;
161                         }
162                 }
163                 /* Decode and validate Sec-WebSocket-Key */
164                 else if (cmp_hdrname_strzn(&hdr->name,
165                                 str_hdr_sec_websocket_key.s, 
166                                 str_hdr_sec_websocket_key.len) == 0) 
167                 {
168                         if (hdr_flags & SEC_WEBSOCKET_KEY)
169                         {
170                                 LM_WARN("%.*s found multiple times\n",
171                                         hdr->name.len, hdr->name.s);
172                                 ws_send_reply(msg, 400,
173                                                 &str_status_bad_request, NULL);
174                                 return 0;
175                         }
176
177                         LM_INFO("found %.*s: %.*s\n",
178                                 hdr->name.len, hdr->name.s,
179                                 hdr->body.len, hdr->body.s);
180                         key = hdr->body;
181                         hdr_flags |= SEC_WEBSOCKET_KEY;
182                 }
183                 /* Decode and validate Sec-WebSocket-Protocol */
184                 else if (cmp_hdrname_strzn(&hdr->name,
185                                 str_hdr_sec_websocket_protocol.s,
186                                 str_hdr_sec_websocket_protocol.len) == 0)
187                 {
188                         strlower(&hdr->body);
189                         if (str_search(&hdr->body, &str_sip) != NULL)
190                         {
191                                 LM_INFO("found %.*s: %.*s\n",
192                                         hdr->name.len, hdr->name.s,
193                                         hdr->body.len, hdr->body.s);
194                                 hdr_flags |= SEC_WEBSOCKET_PROTOCOL;
195                         }
196                 }
197                 /* Decode and validate Sec-WebSocket-Version */
198                 else if (cmp_hdrname_strzn(&hdr->name,
199                                 str_hdr_sec_websocket_version.s,
200                                 str_hdr_sec_websocket_version.len) == 0)
201                 {
202                         if (hdr_flags & SEC_WEBSOCKET_VERSION)
203                         {
204                                 LM_WARN("%.*s found multiple times\n",
205                                         hdr->name.len, hdr->name.s);
206                                 ws_send_reply(msg, 400,
207                                                 &str_status_bad_request, NULL);
208                                 return 0;
209                         }
210
211                         str2sint(&hdr->body, &version);
212
213                         if (version != WS_VERSION)
214                         {
215                                 LM_WARN("Unsupported protocol version %.*s\n",
216                                         hdr->body.len, hdr->body.s);
217                                 headers.s = headers_buf;
218                                 headers.len = snprintf(headers.s, HDR_BUF_LEN,
219                                         "%.*s: %d\r\n",
220                                         str_hdr_sec_websocket_version.len,
221                                         str_hdr_sec_websocket_version.s,
222                                         WS_VERSION);
223                                 ws_send_reply(msg, 426,
224                                                 &str_status_upgrade_required,
225                                                 &headers);
226                                 return 0;
227                         }
228
229                         LM_INFO("found %.*s: %.*s\n",
230                                 hdr->name.len, hdr->name.s,
231                                 hdr->body.len, hdr->body.s);
232                         hdr_flags |= SEC_WEBSOCKET_VERSION;
233                 }
234
235                 hdr = hdr->next;
236         }
237
238         /* Final check that all required headers/values were found */
239         if (hdr_flags != REQUIRED_HEADERS)
240         {
241                 LM_WARN("required headers not present\n");
242                 headers.s = headers_buf;
243                 headers.len = snprintf(headers.s, HDR_BUF_LEN,
244                                         "%.*s: %.*s\r\n"
245                                         "%.*s: %d\r\n",
246                                         str_hdr_sec_websocket_protocol.len,
247                                         str_hdr_sec_websocket_protocol.s,
248                                         str_sip.len, str_sip.s,
249                                         str_hdr_sec_websocket_version.len,
250                                         str_hdr_sec_websocket_version.s,
251                                         WS_VERSION);
252                 ws_send_reply(msg, 400, &str_status_bad_request, NULL);
253                 return 0;
254         }
255
256         /* Construct reply_key */
257         reply_key.s = (char *) pkg_malloc(
258                                 (key.len + str_ws_guid.len) * sizeof(char)); 
259         if (reply_key.s == NULL)
260         {
261                 LM_ERR("allocating pkg memory\n");
262                 ws_send_reply(msg, 500, &str_status_internal_server_error,
263                                 NULL);
264                 return 0;
265         }
266         memcpy(reply_key.s, key.s, key.len);
267         memcpy(reply_key.s + key.len, str_ws_guid.s, str_ws_guid.len);
268         reply_key.len = key.len + str_ws_guid.len;
269         SHA1((const unsigned char *) reply_key.s, reply_key.len, sha1);
270         pkg_free(reply_key.s);
271         reply_key.s = key_buf;
272         reply_key.len = base64_enc(sha1, 20,
273                                 (unsigned char *) reply_key.s, KEY_BUF_LEN);
274
275         /* Build headers for reply */
276         headers.s = headers_buf;
277         headers.len = snprintf(headers.s, HDR_BUF_LEN,
278                         "%.*s: %.*s\r\n"
279                         "%.*s: %.*s\r\n"
280                         "%.*s: %.*s\r\n"
281                         "%.*s: %.*s\r\n",
282                         str_hdr_upgrade.len, str_hdr_upgrade.s,
283                         str_websocket.len, str_websocket.s,
284                         str_hdr_connection.len, str_hdr_connection.s,
285                         str_upgrade.len, str_upgrade.s,
286                         str_hdr_sec_websocket_accept.len,
287                         str_hdr_sec_websocket_accept.s, reply_key.len,
288                         reply_key.s, str_hdr_sec_websocket_protocol.len,
289                         str_hdr_sec_websocket_protocol.s, str_sip.len,
290                         str_sip.s);
291
292         /* TODO: make sure Kamailio core sends future requests on this
293                  connection directly to this module */
294
295         /* Send reply */
296         ws_send_reply(msg, 101, &str_status_switching_protocols, &headers);
297
298         return 0;
299 }
300
301 struct mi_root *ws_mi_disable(struct mi_root *cmd, void *param)
302 {
303         *ws_enabled = 0;
304         LM_WARN("disabling websockets - new connections will be dropped\n");
305         return init_mi_tree(200, MI_OK_S, MI_OK_LEN);
306 }
307
308 struct mi_root *ws_mi_enable(struct mi_root *cmd, void *param)
309 {
310         *ws_enabled = 1;
311         LM_WARN("enabling websockets\n");
312         return init_mi_tree(200, MI_OK_S, MI_OK_LEN);
313 }