b46094f3157dfa393e82893a559a946354aff360
[sip-router] / modules / websocket / ws_handshake.c
1 /*
2  * $Id$
3  *
4  * Copyright (C) 2012 Crocodile RCS Ltd
5  *
6  * This file is part of Kamailio, a free SIP server.
7  *
8  * Kamailio is free software; you can redistribute it and/or modify
9  * it under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 2 of the License, or
11  * (at your option) any later version
12  *
13  * Kamailio is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License 
19  * along with this program; if not, write to the Free Software 
20  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  *
22  */
23
24 #include <openssl/sha.h>
25
26 #include "../../basex.h"
27 #include "../../data_lump_rpl.h"
28 #include "../../dprint.h"
29 #include "../../lib/kcore/cmpapi.h"
30 #include "../../parser/msg_parser.h"
31 #include "../sl/sl.h"
32 #include "ws_handshake.h"
33 #include "ws_mod.h"
34
35 #define WS_VERSION              (13)
36
37 static str str_sip = str_init("sip");
38 static str str_websocket = str_init("websocket");
39 static str str_ws_guid = str_init("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
40
41 /* HTTP headers */
42 static str str_connection = str_init("Connection");
43 static str str_upgrade = str_init("Upgrade");
44 static str str_sec_websocket_accept = str_init("Sec-WebSocket-Accept");
45 static str str_sec_websocket_key = str_init("Sec-WebSocket-Key");
46 static str str_sec_websocket_protocol = str_init("Sec-WebSocket-Protocol");
47 static str str_sec_websocket_version = str_init("Sec-WebSocket-Init");
48 #define CONNECTION              (1<<0)
49 #define UPGRADE                 (1<<1)
50 #define SEC_WEBSOCKET_ACCEPT    (1<<2)
51 #define SEC_WEBSOCKET_KEY       (1<<3)
52 #define SEC_WEBSOCKET_PROTOCOL  (1<<4)
53 #define SEC_WEBSOCKET_VERSION   (1<<5)
54
55 #define REQUIRED_HEADERS        (CONNECTION | UPGRADE | SEC_WEBSOCKET_KEY\
56                                         | SEC_WEBSOCKET_PROTOCOL\
57                                         | SEC_WEBSOCKET_VERSION)
58
59 /* HTTP response text */
60 static str str_switching_protocols = str_init("Switching Protocols");
61 static str str_bad_request = str_init("Bad Request");
62 static str str_upgrade_required = str_init("Upgrade Required");
63 static str str_internal_server_error = str_init("Internal Server Error");
64
65 #define HDR_BUF_LEN             (256)
66 static char headers_buf[HDR_BUF_LEN];
67
68 #define KEY_BUF_LEN             (28)
69 static char key_buf[KEY_BUF_LEN];
70
71 static int ws_send_reply(sip_msg_t *msg, int code, str *reason, str *hdrs)
72 {
73         if (hdrs && hdrs->len > 0)
74         {
75                 if (add_lump_rpl(msg, hdrs->s, hdrs->len, LUMP_RPL_HDR) == 0)
76                 {
77                         LM_ERR("inserting extra-headers lump\n");
78                         return -1;
79                 }
80         }
81
82         if (ws_slb.freply(msg, code, reason) < 0)
83         {
84                 LM_ERR("sending reply\n");
85                 return -1;
86         }
87
88         return 0;
89 }
90
91 int ws_handle_handshake(struct sip_msg *msg)
92 {
93         str key = {0, 0}, headers = {0, 0}, reply_key = {0, 0};
94         unsigned char sha1[20];
95         unsigned int hdr_flags = 0;
96         int version;
97         struct hdr_field *hdr = msg->headers;
98
99         while (hdr != NULL)
100         {
101                 /* Decode and validate Connection */
102                 if (cmp_hdrname_strzn(&hdr->name,
103                                 str_connection.s,
104                                 str_connection.len) == 0)
105                 {
106                         /* TODO: validate Connection body */
107                         hdr_flags |= CONNECTION;
108                 }
109                 /* Decode and validate Upgrade */
110                 else if (cmp_hdrname_strzn(&hdr->name,
111                                 str_upgrade.s,
112                                 str_upgrade.len) == 0)
113                 {
114                         /* TODO: validate Upgrade body */
115                         hdr_flags |= UPGRADE;
116                 }
117                 /* Decode and validate Sec-WebSocket-Key */
118                 else if (cmp_hdrname_strzn(&hdr->name,
119                                 str_sec_websocket_key.s, 
120                                 str_sec_websocket_key.len) == 0) 
121                 {
122                         if (hdr_flags & SEC_WEBSOCKET_KEY)
123                         {
124                                 LM_WARN("%.*s found multiple times\n",
125                                         hdr->name.len, hdr->name.s);
126                                 ws_send_reply(msg, 400, &str_bad_request, NULL);
127                                 return 0;
128                         }
129
130                         key = hdr->body;
131                         hdr_flags |= SEC_WEBSOCKET_KEY;
132                 }
133                 /* Decode and validate Sec-WebSocket-Protocol */
134                 else if (cmp_hdrname_strzn(&hdr->name,
135                                 str_sec_websocket_protocol.s,
136                                 str_sec_websocket_protocol.len) == 0)
137                 {
138                         /* TODO: better validation of sip... */
139                         if (str_search(&hdr->body, &str_sip) != NULL)
140                                 hdr_flags |= SEC_WEBSOCKET_PROTOCOL;
141                 }
142                 /* Decode and validate Sec-WebSocket-Version */
143                 else if (cmp_hdrname_strzn(&hdr->name,
144                                 str_sec_websocket_version.s,
145                                 str_sec_websocket_version.len) == 0)
146                 {
147                         if (hdr_flags & SEC_WEBSOCKET_VERSION)
148                         {
149                                 LM_WARN("%.*s found multiple times\n",
150                                         hdr->name.len, hdr->name.s);
151                                 ws_send_reply(msg, 400, &str_bad_request, NULL);
152                                 return 0;
153                         }
154
155                         str2sint(&hdr->body, &version);
156
157                         if (version != WS_VERSION)
158                         {
159                                 LM_WARN("Unsupported protocol version %.*s\n",
160                                         hdr->body.len, hdr->body.s);
161                                 headers.s = headers_buf;
162                                 headers.len = snprintf(headers.s, HDR_BUF_LEN,
163                                         "%.*s: %d\r\n",
164                                         str_sec_websocket_version.len,
165                                         str_sec_websocket_version.s,
166                                         WS_VERSION);
167                                 ws_send_reply(msg, 426, &str_upgrade_required,
168                                                 &headers);
169                                 return 0;
170                         }
171
172                         hdr_flags |= SEC_WEBSOCKET_VERSION;
173                 }
174
175                 hdr = hdr->next;
176         }
177
178         /* Final check that all required headers/values were found */
179         if (hdr_flags != REQUIRED_HEADERS)
180         {
181                 LM_WARN("required headers not present\n");
182                 headers.s = headers_buf;
183                 headers.len = snprintf(headers.s, HDR_BUF_LEN,
184                                         "%.*s: %.*s\r\n"
185                                         "%.*s: %d\r\n",
186                                         str_sec_websocket_protocol.len, str_sec_websocket_protocol.s, str_sip.len, str_sip.s,
187                                         str_sec_websocket_version.len, str_sec_websocket_version.s, WS_VERSION);
188                 ws_send_reply(msg, 400, &str_bad_request, NULL);
189                 return 0;
190         }
191
192         /* Construct reply_key */
193         reply_key.s = (char *) pkg_malloc(
194                                 (key.len + str_ws_guid.len) * sizeof(char)); 
195         if (reply_key.s == NULL)
196         {
197                 LM_ERR("allocating pkg memory\n");
198                 ws_send_reply(msg, 500, &str_internal_server_error, NULL);
199                 return 0;
200         }
201         memcpy(reply_key.s, key.s, key.len);
202         memcpy(reply_key.s + key.len, str_ws_guid.s, str_ws_guid.len);
203         reply_key.len = key.len + str_ws_guid.len;
204         SHA1((const unsigned char *) reply_key.s, reply_key.len, sha1);
205         pkg_free(reply_key.s);
206         reply_key.s = key_buf;
207         reply_key.len = base64_enc(sha1, 20,
208                                 (unsigned char *) reply_key.s, KEY_BUF_LEN);
209
210         /* Build headers for reply */
211         headers.s = headers_buf;
212         headers.len = snprintf(headers.s, HDR_BUF_LEN,
213                         "%.*s: %.*s\r\n"
214                         "%.*s: %.*s\r\n"
215                         "%.*s: %.*s\r\n"
216                         "%.*s: %.*s\r\n",
217                         str_upgrade.len, str_upgrade.s, str_websocket.len, str_websocket.s,
218                         str_connection.len, str_connection.s, str_upgrade.len, str_upgrade.s,
219                         str_sec_websocket_accept.len, str_sec_websocket_accept.s, reply_key.len, reply_key.s,
220                         str_sec_websocket_protocol.len, str_sec_websocket_protocol.s, str_sip.len, str_sip.s);
221
222         /* TODO: make sure Kamailio core sends future requests on this
223                  connection directly to this module */
224
225         /* Send reply */
226         ws_send_reply(msg, 101, &str_switching_protocols, &headers);
227
228         return 0;
229 }