auth_diameter: avoid passing large structs as params and better error handling
[sip-router] / src / modules / auth_diameter / tcp_comm.c
1 /*
2  * Digest Authentication - Diameter support
3  *
4  * Copyright (C) 2001-2003 FhG Fokus
5  *
6  * This file is part of Kamailio, a free SIP server.
7  *
8  * Kamailio is free software; you can redistribute it and/or modify
9  * it under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 2 of the License, or
11  * (at your option) any later version
12  * 
13  * Kamailio is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License 
19  * along with this program; if not, write to the Free Software 
20  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <unistd.h>
26 #include <sys/types.h>
27 #include <sys/time.h>
28 #include <sys/socket.h>
29 #include <netinet/in.h>
30 #include <netdb.h> 
31 #include <errno.h>
32
33 /* memory management */
34 #include "../../core/mem/mem.h"
35
36 /* printing messages, dealing with strings and other utils */
37 #include "../../core/dprint.h"
38 #include "../../core/str.h"
39
40 /* headers defined by this module */
41 #include "auth_diameter.h"
42 #include "defs.h"
43 #include "tcp_comm.h"
44 #include "diameter_msg.h"
45
/* upper bound on the stale answers (wrong message id) tcp_send_recv() skips */
#define MAX_TRIES 10
47
/*
 * Open a TCP connection to the DIAMETER client listening on host:port.
 *
 * The host name is resolved with gethostbyname(), which is IPv4 only here.
 * Each returned address is tried in turn until a connect() succeeds.
 * TODO(review): consider getaddrinfo(); gethostbyname() is obsolescent.
 *
 * Returns the connected socket descriptor, or -1 on failure (logged).
 */
int init_mytcp(char* host, int port)
{
	struct sockaddr_in serv_addr;
	struct hostent *server;
	int sockfd;
	int i;

	if (host == NULL) {
		LM_ERR("no DIAMETER client host given\n");
		return -1;
	}
	if (port <= 0 || port > 65535) {
		LM_ERR("invalid DIAMETER client port %d\n", port);
		return -1;
	}

	server = gethostbyname(host);
	if (server == NULL) {
		LM_ERR("error finding the host %s\n", host);
		return -1;
	}
	/* sin_addr holds exactly one IPv4 address: refuse anything else instead
	 * of overflowing it with a longer address */
	if (server->h_addrtype != AF_INET
			|| server->h_length != (int)sizeof(serv_addr.sin_addr)) {
		LM_ERR("host %s did not resolve to an IPv4 address\n", host);
		return -1;
	}

	for (i = 0; server->h_addr_list[i] != NULL; i++) {
		/* a socket whose connect() failed is in an unspecified state,
		 * so every attempt gets a fresh one */
		sockfd = socket(AF_INET, SOCK_STREAM, 0);
		if (sockfd < 0) {
			LM_ERR("error creating the socket: %s\n", strerror(errno));
			return -1;
		}

		memset(&serv_addr, 0, sizeof(serv_addr));
		serv_addr.sin_family = AF_INET;
		memcpy(&serv_addr.sin_addr, server->h_addr_list[i],
				sizeof(serv_addr.sin_addr));
		serv_addr.sin_port = htons((unsigned short)port);

		if (connect(sockfd, (const struct sockaddr *)&serv_addr,
					sizeof(serv_addr)) == 0)
			return sockfd;

		LM_ERR("error connecting to the DIAMETER client %s:%d: %s\n",
				host, port, strerror(errno));
		close(sockfd);
	}

	if (i == 0)
		LM_ERR("host %s has no address\n", host);
	return -1;
}
87
88
89
90 void reset_read_buffer(rd_buf_t *rb)
91 {
92         rb->ret_code            = 0;
93         rb->chall_len           = 0;
94         if(rb->chall)
95                 pkg_free(rb->chall);
96         rb->chall                       = 0;
97
98         rb->first_4bytes        = 0;
99         rb->buf_len                     = 0;
100         if(rb->buf)
101                 pkg_free(rb->buf);
102         rb->buf                         = 0;
103 }
104
105 /* read from a socket, an AAA message buffer */
106 int do_read( int socket, rd_buf_t *p)
107 {
108         unsigned char  *ptr;
109         unsigned int   wanted_len, len;
110         int n;
111
112         if (p->buf==0)
113         {
114                 wanted_len = sizeof(p->first_4bytes) - p->buf_len;
115                 ptr = ((unsigned char*)&(p->first_4bytes)) + p->buf_len;
116         }
117         else
118         {
119                 wanted_len = p->first_4bytes - p->buf_len;
120                 ptr = p->buf + p->buf_len;
121         }
122
123         while( (n=recv( socket, ptr, wanted_len, MSG_DONTWAIT ))>0 ) 
124         {
125 //              LM_DBG("(sock=%d)  -> n=%d (expected=%d)\n", p->sock,n,wanted_len);
126                 p->buf_len += n;
127                 if (n<wanted_len)
128                 {
129                         //LM_DBG("only %d bytes read from %d expected\n",n,wanted_len);
130                         wanted_len -= n;
131                         ptr += n;
132                 }
133                 else 
134                 {
135                         if (p->buf==0)
136                         {
137                                 /* I just finished reading the first 4 bytes from msg */
138                                 len = ntohl(p->first_4bytes)&0x00ffffff;
139                                 if (len<AAA_MSG_HDR_SIZE || len>MAX_AAA_MSG_SIZE)
140                                 {
141                                         LM_ERR(" (sock=%d): invalid message "
142                                                 "length read %u (%x)\n", socket, len, p->first_4bytes);
143                                         goto error;
144                                 }
145                                 //LM_DBG("message length = %d(%x)\n",len,len);
146                                 if ( (p->buf=pkg_malloc(len))==0  )
147                                 {
148                                         LM_ERR("no more pkg memory\n");
149                                         goto error;
150                                 }
151                                 *((unsigned int*)p->buf) = p->first_4bytes;
152                                 p->buf_len = sizeof(p->first_4bytes);
153                                 p->first_4bytes = len;
154                                 /* update the reading position and len */
155                                 ptr = p->buf + p->buf_len;
156                                 wanted_len = p->first_4bytes - p->buf_len;
157                         }
158                         else
159                         {
160                                 /* I finished reading the whole message */
161                                 LM_DBG("(sock=%d): whole message read (len=%d)!\n",
162                                         socket, p->first_4bytes);
163                                 return CONN_SUCCESS;
164                         }
165                 }
166         }
167
168         if (n==0)
169         {
170                 LM_INFO("(sock=%d): FIN received\n", socket);
171                 return CONN_CLOSED;
172         }
173         if ( n==-1 && errno!=EINTR && errno!=EAGAIN )
174         {
175                 LM_ERR(" (sock=%d): n=%d , errno=%d (%s)\n",
176                         socket, n, errno, strerror(errno));
177                 goto error;
178         }
179 error:
180         return CONN_ERROR;
181 }
182
183
184 /* send a message over an already opened TCP connection */
185 int tcp_send_recv(int sockfd, char* buf, int len, rd_buf_t* rb, 
186                                         unsigned int waited_id)
187 {
188         int n, number_of_tries;
189         fd_set active_fd_set, read_fd_set;
190         struct timeval tv;
191         unsigned long int result_code;
192         AAAMessage *msg;
193         AAA_AVP *avp;
194         char serviceType;
195         unsigned int m_id;
196
197         /* try to write the message to the Diameter client */
198         while( (n=write(sockfd, buf, len))==-1 ) 
199         {
200                 if (errno==EINTR)
201                         continue;
202                 LM_ERR("write returned error: %s\n", strerror(errno));
203                 return AAA_ERROR;
204         }
205
206         if (n!=len) 
207         {
208                 LM_ERR("write gave no error but wrote less than asked\n");
209                 return AAA_ERROR;
210         }
211
212         /* wait for the answer a limited amount of time */
213         tv.tv_sec = MAX_WAIT_SEC;
214         tv.tv_usec = MAX_WAIT_USEC;
215
216         /* Initialize the set of active sockets. */
217         FD_ZERO (&active_fd_set);
218         FD_SET (sockfd, &active_fd_set);
219         number_of_tries = 0;
220
221         while(number_of_tries<MAX_TRIES)
222         {
223                 read_fd_set = active_fd_set;
224                 if (select (sockfd+1, &read_fd_set, NULL, NULL, &tv) < 0)
225                 {
226                         LM_ERR("select function failed\n");
227                         return AAA_ERROR;
228                 }
229 /*
230                 if (!FD_ISSET (sockfd, &read_fd_set))
231                 {
232                         LM_ERR("no response message received\n");
233 //                      return AAA_ERROR;
234                 }
235 */
236                 /* Data arriving on a already-connected socket. */
237                 reset_read_buffer(rb);
238                 switch( do_read(sockfd, rb) )
239                 {
240                         case CONN_ERROR:
241                                 LM_ERR("error when trying to read from socket\n");
242                                 return AAA_CONN_CLOSED;
243                         case CONN_CLOSED:
244                                 LM_ERR("connection closed by diameter client!\n");
245                                 return AAA_CONN_CLOSED;
246                 }
247                 
248                 /* obtain the structure corresponding to the message */
249                 msg = AAATranslateMessage(rb->buf, rb->buf_len, 0);     
250                 if(!msg)
251                 {
252                         LM_ERR("message structure not obtained\n");     
253                         return AAA_ERROR;
254                 }
255                 avp = AAAFindMatchingAVP(msg, NULL, AVP_SIP_MSGID,
256                                                                 vendorID, AAA_FORWARD_SEARCH);
257                 if(!avp)
258                 {
259                         LM_ERR("AVP_SIP_MSGID not found\n");
260                         return AAA_ERROR;
261                 }
262                 m_id = *((unsigned int*)(avp->data.s));
263                 LM_DBG("######## m_id=%d\n", m_id);
264                 if(m_id!=waited_id)
265                 {
266                         number_of_tries ++;
267                         LM_NOTICE("old message received\n");
268                         continue;
269                 }
270                 goto next;
271         }
272
273         LM_ERR("too many old messages received\n");
274         return AAA_TIMEOUT;
275 next:
276         /* Finally die correct answer */
277         avp = AAAFindMatchingAVP(msg, NULL, AVP_Service_Type,
278                                                         vendorID, AAA_FORWARD_SEARCH);
279         if(!avp)
280         {
281                 LM_ERR("AVP_Service_Type not found\n");
282                 return AAA_ERROR;
283         }
284         serviceType = avp->data.s[0];
285
286         result_code = ntohl(*((unsigned long int*)(msg->res_code->data.s)));
287         switch(result_code)
288         {
289                 case AAA_SUCCESS:                                       /* 2001 */
290                         rb->ret_code = AAA_AUTHORIZED;
291                         break;
292                 case AAA_AUTHENTICATION_REJECTED:       /* 4001 */
293                         if(serviceType!=SIP_AUTH_SERVICE)
294                         {
295                                 rb->ret_code = AAA_NOT_AUTHORIZED;
296                                 break;
297                         }
298                         avp = AAAFindMatchingAVP(msg, NULL, AVP_Challenge,
299                                                         vendorID, AAA_FORWARD_SEARCH);
300                         if(!avp)
301                         {
302                                 LM_ERR("AVP_Response not found\n");
303                                 rb->ret_code = AAA_SRVERR;
304                                 break;
305                         }
306                         rb->chall_len=avp->data.len;
307                         rb->chall = (unsigned char*)pkg_malloc(avp->data.len*sizeof(unsigned char));
308                         if(rb->chall == NULL)
309                         {
310                                 LM_ERR("no more pkg memory\n");
311                                 rb->ret_code = AAA_SRVERR;
312                                 break;
313                         }
314                         memcpy(rb->chall, avp->data.s, avp->data.len);
315                         rb->ret_code = AAA_CHALENGE;
316                         break;
317                 case AAA_AUTHORIZATION_REJECTED:        /* 5003 */
318                         rb->ret_code = AAA_NOT_AUTHORIZED;
319                         break;
320                 default:                                                        /* error */
321                         rb->ret_code = AAA_SRVERR;
322         }
323         
324     return rb->ret_code;        
325 }
/*
 * Shut down both directions of the connection and release its descriptor.
 * A negative descriptor, meaning no connection, is ignored.
 *
 * shutdown() errors, e.g. ENOTCONN on an already-reset peer, are ignored
 * on purpose: the descriptor must be closed in any case.
 */
void close_tcp_connection(int sfd)
{
	if (sfd < 0)
		return;
	shutdown(sfd, SHUT_RDWR);
	close(sfd);
}
330
331