auth_diameter: avoid passing large structs as params and better error handling
[sip-router] / src / modules / auth_diameter / message.c
1 /*
2  * Copyright (C) 2002-2003 FhG Fokus
3  *
4  * This file is part of disc, a free diameter server/client.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 2 of the License, or
9  * (at your option) any later version
10  *
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License 
17  * along with this program; if not, write to the Free Software 
18  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
19  */
20
21
22 #include <stdlib.h>
23 #include <string.h>
24 #include <sys/types.h>
25 #include <netinet/in.h>
26
27
28 #include "../../core/mem/shm_mem.h"
29 #include "../../core/dprint.h"
30 #include "diameter_msg.h"
31
32 #define get_3bytes(_b) \
33         ((((unsigned int)(_b)[0])<<16)|(((unsigned int)(_b)[1])<<8)|\
34         (((unsigned int)(_b)[2])))
35
36 #define get_4bytes(_b) \
37         ((((unsigned int)(_b)[0])<<24)|(((unsigned int)(_b)[1])<<16)|\
38         (((unsigned int)(_b)[2])<<8)|(((unsigned int)(_b)[3])))
39
40 #define set_3bytes(_b,_v) \
41         {(_b)[0]=((_v)&0x00ff0000)>>16;(_b)[1]=((_v)&0x0000ff00)>>8;\
42         (_b)[2]=((_v)&0x000000ff);}
43
44 #define set_4bytes(_b,_v) \
45         {(_b)[0]=((_v)&0xff000000)>>24;(_b)[1]=((_v)&0x00ff0000)>>16;\
46         (_b)[2]=((_v)&0x0000ff00)>>8;(_b)[3]=((_v)&0x000000ff);}
47
48 #define to_32x_len( _len_ ) \
49         ( (_len_)+(((_len_)&3)?4-((_len_)&3):0) )
50
51
52 /* from a AAAMessage structure, a buffer to be send is build
53  */
54 AAAReturnCode AAABuildMsgBuffer( AAAMessage *msg )
55 {
56         unsigned char *p;
57         AAA_AVP       *avp;
58
59         /* first let's compute the length of the buffer */
60         msg->buf.len = AAA_MSG_HDR_SIZE; /* AAA message header size */
61         /* count and add the avps */
62         for(avp=msg->avpList.head;avp;avp=avp->next) {
63                 msg->buf.len += AVP_HDR_SIZE(avp->flags)+ to_32x_len( avp->data.len );
64         }
65
66 //      LM_DBG("xxxx len=%d\n",msg->buf.len);
67         /* allocate some memory */
68         msg->buf.s = (char*)ad_malloc( msg->buf.len );
69         if (!msg->buf.s) {
70                 LM_ERR(" no more pkg memory!\n");
71                 goto error;
72         }
73         memset(msg->buf.s, 0, msg->buf.len);
74
75         /* fill in the buffer */
76         p = (unsigned char*)msg->buf.s;
77         /* DIAMETER HEADER */
78         /* message length */
79         ((unsigned int*)p)[0] =htonl(msg->buf.len);
80         /* Diameter Version */
81         *p = 1;
82         p += VER_SIZE + MESSAGE_LENGTH_SIZE;
83         /* command code */
84         ((unsigned int*)p)[0] = htonl(msg->commandCode);
85         /* flags */
86         *p = (unsigned char)msg->flags;
87         p += FLAGS_SIZE + COMMAND_CODE_SIZE;
88         /* application-ID */
89         ((unsigned int*)p)[0] = htonl(msg->applicationId);
90         p += APPLICATION_ID_SIZE;
91         /* hop by hop id */
92         ((unsigned int*)p)[0] = msg->hopbyhopId;
93         p += HOP_BY_HOP_IDENTIFIER_SIZE;
94         /* end to end id */
95         ((unsigned int*)p)[0] = msg->endtoendId;
96         p += END_TO_END_IDENTIFIER_SIZE;
97
98         /* AVPS */
99         for(avp=msg->avpList.head;avp;avp=avp->next) {
100                 /* AVP HEADER */
101                 /* avp code */
102                 set_4bytes(p,avp->code);
103                 p +=4;
104                 /* flags */
105                 (*p++) = (unsigned char)avp->flags;
106                 /* avp length */
107                 set_3bytes(p, (AVP_HDR_SIZE(avp->flags)+avp->data.len) );
108                 p += 3;
109                 /* vendor id */
110                 if ((avp->flags&0x80)!=0) {
111                         set_4bytes(p,avp->vendorId);
112                         p +=4;
113                 }
114                 /* data */
115                 memcpy( p, avp->data.s, avp->data.len);
116                 p += to_32x_len( avp->data.len );
117         }
118
119         if ((char*)p-msg->buf.s!=msg->buf.len) {
120                 LM_ERR("mismatch between len and buf!\n");
121                 ad_free( msg->buf.s );
122                 msg->buf.s = 0;
123                 msg->buf.len = 0;
124                 goto error;
125         }
126 //      LM_DBG("Message: %.*s\n", msg->buf.len, msg->buf.s);
127         return AAA_ERR_SUCCESS;
128 error:
129         return -1;
130 }
131
132
133
134 /* frees a message allocated through AAANewMessage()
135  */
136 AAAReturnCode  AAAFreeMessage(AAAMessage **msg)
137 {
138         AAA_AVP *avp_t;
139         AAA_AVP *avp;
140
141         /* param check */
142         if (!msg || !(*msg))
143                 goto done;
144
145         /* free the avp list */
146         avp = (*msg)->avpList.head;
147         while (avp) {
148                 avp_t = avp;
149                 avp = avp->next;
150                 /*free the avp*/
151                 AAAFreeAVP(&avp_t);
152         }
153
154         /* free the buffer (if any) */
155         if ( (*msg)->buf.s )
156                 ad_free( (*msg)->buf.s );
157
158         /* free the AAA msg */
159         ad_free(*msg);
160         msg = 0;
161
162 done:
163         return AAA_ERR_SUCCESS;
164 }
165
166
167
168 /* Sets the proper result_code into the Result-Code AVP; thus avp must already
169  * exists into the reply message */
170 AAAReturnCode  AAASetMessageResultCode(
171         AAAMessage *message,
172         AAAResultCode resultCode)
173 {
174         if ( !is_req(message) && message->res_code) {
175                 *((unsigned int*)(message->res_code->data.s)) = htonl(resultCode);
176                 return AAA_ERR_SUCCESS;
177         }
178         return AAA_ERR_FAILURE;
179 }
180
181
182
183 /* This function convert message to message structure */
184 AAAMessage* AAATranslateMessage( unsigned char* source, unsigned int sourceLen,
185                                                                                                                         int attach_buf)
186 {
187         unsigned char *ptr;
188         AAAMessage    *msg;
189         unsigned char version;
190         unsigned int  msg_len;
191         AAA_AVP       *avp;
192         unsigned int  avp_code;
193         unsigned char avp_flags;
194         unsigned int  avp_len;
195         unsigned int  avp_vendorID;
196         unsigned int  avp_data_len;
197
198         /* check the params */
199         if( !source || !sourceLen || sourceLen<AAA_MSG_HDR_SIZE) {
200                 LM_ERR(" invalid buffered received!\n");
201                 goto error;
202         }
203
204         /* inits */
205         msg = 0;
206         avp = 0;
207         ptr = source;
208
209         /* alloc a new message structure */
210         msg = (AAAMessage*)ad_malloc(sizeof(AAAMessage));
211         if (!msg) {
212                 LM_ERR(" no more free memory!!\n");
213                 goto error;
214         }
215         memset(msg,0,sizeof(AAAMessage));
216
217         /* get the version */
218         version = (unsigned char)*ptr;
219         ptr += VER_SIZE;
220         if (version!=1) {
221                 LM_ERR(" invalid version [%d]in AAA msg\n",version);
222                 goto error;
223         }
224
225         /* message length */
226         msg_len = get_3bytes( ptr );
227         ptr += MESSAGE_LENGTH_SIZE;
228         if (msg_len>sourceLen) {
229                 LM_ERR(" AAA message len [%d] bigger then"
230                         " buffer len [%d]\n",msg_len,sourceLen);
231                 goto error;
232         }
233
234         /* command flags */
235         msg->flags = *ptr;
236         ptr += FLAGS_SIZE;
237
238         /* command code */
239         msg->commandCode = get_3bytes( ptr );
240         ptr += COMMAND_CODE_SIZE;
241
242         /* application-Id */
243         msg->applicationId = get_4bytes( ptr );
244         ptr += APPLICATION_ID_SIZE;
245
246         /* Hop-by-Hop-Id */
247         msg->hopbyhopId = *((unsigned int*)ptr);
248         ptr += HOP_BY_HOP_IDENTIFIER_SIZE;
249
250         /* End-to-End-Id */
251         msg->endtoendId = *((unsigned int*)ptr);
252         ptr += END_TO_END_IDENTIFIER_SIZE;
253
254         /* start decoding the AVPS */
255         while (ptr < source+msg_len) {
256                 if (ptr+AVP_HDR_SIZE(0x80)>source+msg_len){
257                         LM_ERR(" source buffer to short!! "
258                                 "Cannot read the whole AVP header!\n");
259                         goto error;
260                 }
261                 /* avp code */
262                 avp_code = get_4bytes( ptr );
263                 ptr += AVP_CODE_SIZE;
264                 /* avp flags */
265                 avp_flags = (unsigned char)*ptr;
266                 ptr += AVP_FLAGS_SIZE;
267                 /* avp length */
268                 avp_len = get_3bytes( ptr );
269                 ptr += AVP_LENGTH_SIZE;
270                 if (avp_len<1) {
271                         LM_ERR(" invalid AVP len [%d]\n", avp_len);
272                         goto error;
273                 }
274                 /* avp vendor-ID */
275                 avp_vendorID = 0;
276                 if (avp_flags&AAA_AVP_FLAG_VENDOR_SPECIFIC) {
277                         avp_vendorID = get_4bytes( ptr );
278                         ptr += AVP_VENDOR_ID_SIZE;
279                 }
280                 /* data length */
281                 avp_data_len = avp_len-AVP_HDR_SIZE(avp_flags);
282                 /*check the data length */
283                 if ( source+msg_len<ptr+avp_data_len) {
284                         LM_ERR(" source buffer to short!! "
285                                 "Cannot read a whole data for AVP!\n");
286                         goto error;
287                 }
288
289                 /* create the AVP */
290                 avp = AAACreateAVP( avp_code, avp_flags, avp_vendorID, (char*)ptr,
291                         avp_data_len, AVP_DONT_FREE_DATA);
292                 if (!avp)
293                         goto error;
294
295                 /* link the avp into aaa message to the end */
296                 if(AAAAddAVPToMessage(msg, avp, msg->avpList.tail)!=AAA_ERR_SUCCESS) {
297                         LM_ERR("failed to add avp to message\n");
298                 }
299
300                 ptr += to_32x_len( avp_data_len );
301         }
302
303         /* link the buffer to the message */
304         if (attach_buf) {
305                 msg->buf.s = (char*)source;
306                 msg->buf.len = msg_len;
307         }
308
309         //AAAPrintMessage( msg );
310         return  msg;
311 error:
312         LM_ERR(" message conversion droped!!\n");
313         AAAFreeMessage(&msg);
314         return 0;
315 }
316
317 /* create a new minimal AAA message */
318 AAAMessage* AAAInMessage(AAACommandCode commandCode, AAAApplicationId appId)
319 {
320         AAAMessage   *msg;
321
322         /* we allocate a new AAAMessage structure and set it to 0 */
323         msg = (AAAMessage*)ad_malloc(sizeof(AAAMessage));
324         if (!msg) 
325         {
326                 LM_ERR("no more pkg memory!\n");
327                 return NULL;
328         }
329         memset(msg, 0, sizeof(AAAMessage));
330
331         /* command code */
332         msg->commandCode = commandCode;
333
334         /* application ID */
335         msg->applicationId = appId;
336
337         /* it's a new request -> set the flag */
338         msg->flags = 0x80;
339
340         return msg;
341 }
342
343
344 /* print as debug all info contained by an aaa message + AVPs
345  */
346 void AAAPrintMessage( AAAMessage *msg)
347 {
348         char    buf[1024];
349         AAA_AVP *avp;
350
351         /* print msg info */
352         LM_DBG("AAA_MESSAGE - %p\n",msg);
353         LM_DBG("\tCode = %u\n",msg->commandCode);
354         LM_DBG("\tFlags = %x\n",msg->flags);
355
356         /*print the AVPs */
357         avp = msg->avpList.head;
358         while (avp) {
359                 AAAConvertAVPToString(avp,buf,1024);
360                 LM_DBG("\n%s\n",buf);
361                 avp=avp->next;
362         }
363 }