modules/ims_qos: added patch for flow-description bug when request originates from...
[sip-router] / modules / acc / diam_tcp.c
1 /*
2  * $Id$
3  *
4  * Copyright (C) 2001-2003 FhG Fokus
5  *
6  * This file is part of Kamailio, a free SIP server.
7  *
8  * Kamailio is free software; you can redistribute it and/or modify
9  * it under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 2 of the License, or
11  * (at your option) any later version
12  *
13  * Kamailio is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22
23 /*! \file
24  * \ingroup acc
25  * \brief Acc:: Diameter TCP connections
26  *
27  * - Module: \ref acc
28  */
29
30 #ifdef DIAM_ACC
31
32 #include <sys/socket.h>
33 #include <netinet/in.h>
34 #include <unistd.h>
35 #include <netdb.h> 
36 #include <stdlib.h>
37 #include <string.h>
38 #include <errno.h>
39
40 #include "../../dprint.h"
41 #include "../../parser/msg_parser.h"
42 #include "../../parser/parse_to.h"
43 #include "../../parser/parse_from.h"
44 #include "../../mem/mem.h"
45
46 #include "diam_message.h"
47 #include "diam_tcp.h"
48 #include "diam_dict.h"
49
50 #define M_NAME "acc"
51
/*! \brief Open a TCP connection to the Diameter client.
 *
 * Resolves \a host with gethostbyname(), creates an IPv4 stream socket
 * and connects it to \a host:\a port.
 *
 * \param host  host name or dotted IPv4 address of the Diameter client
 * \param port  TCP port, in host byte order
 * \return the connected socket descriptor, or -1 on error. No descriptor
 *         is left open on any failure path.
 */
int init_mytcp(char* host, int port)
{
	int sockfd;
	struct sockaddr_in serv_addr;
	struct hostent *server;

	/* resolve before creating the socket so a lookup failure cannot
	 * leak a descriptor */
	server = gethostbyname(host);
	if (server == NULL) {
		LM_ERR("failed to find the host\n");
		return -1;
	}
	/* only IPv4 results fit into sockaddr_in.sin_addr */
	if (server->h_addrtype != AF_INET
			|| server->h_length != (int)sizeof(serv_addr.sin_addr)
			|| server->h_addr_list[0] == NULL) {
		LM_ERR("host does not resolve to an IPv4 address\n");
		return -1;
	}

	sockfd = socket(PF_INET, SOCK_STREAM, 0);
	if (sockfd < 0) {
		LM_ERR("failed to create the socket\n");
		return -1;
	}

	memset(&serv_addr, 0, sizeof(serv_addr));
	serv_addr.sin_family = AF_INET;
	memcpy(&serv_addr.sin_addr.s_addr, server->h_addr_list[0],
			server->h_length);
	serv_addr.sin_port = htons(port);

	if (connect(sockfd, (const struct sockaddr *)&serv_addr,
				sizeof(serv_addr)) < 0) {
		LM_ERR("failed to connect to the DIAMETER client: %s\n",
				strerror(errno));
		close(sockfd);
		return -1;
	}

	return sockfd;
}
89
90 /*! \brief send a message over an already opened TCP connection */
91 int tcp_send_recv(int sockfd, char* buf, int len, rd_buf_t* rb, 
92                                         unsigned int waited_id)
93 {
94         int n, number_of_tries;
95         fd_set active_fd_set, read_fd_set;
96         struct timeval tv;
97         unsigned long int result_code;
98         AAAMessage *msg;
99         AAA_AVP *avp;
100         char serviceType;
101         unsigned int m_id;
102
103         /* try to write the message to the Diameter client */
104         while( (n=write(sockfd, buf, len))==-1 ) 
105         {
106                 if (errno==EINTR)
107                         continue;
108                 LM_ERR("write returned error: %s\n", strerror(errno));
109                 return AAA_ERROR;
110         }
111
112         if (n!=len) 
113         {
114                 LM_ERR("write gave no error but wrote less than asked\n");
115                 return AAA_ERROR;
116         }
117         /* wait for the answer a limited amount of time */
118         tv.tv_sec = MAX_WAIT_SEC;
119         tv.tv_usec = MAX_WAIT_USEC;
120
121         /* Initialize the set of active sockets. */
122         FD_ZERO (&active_fd_set);
123         FD_SET (sockfd, &active_fd_set);
124         number_of_tries = 0;
125
126         while(number_of_tries<MAX_TRIES)
127         {
128                 read_fd_set = active_fd_set;
129                 if (select (sockfd+1, &read_fd_set, NULL, NULL, &tv) < 0)
130                 {
131                         LM_ERR("select function failed\n");
132                         return AAA_ERROR;
133                 }
134
135 /*              if (!FD_ISSET (sockfd, &read_fd_set))
136                 {
137                         LM_ERR("no response received\n");
138 //                      return AAA_ERROR;
139                 }
140 */              /* Data arriving on a already-connected socket. */
141                 reset_read_buffer(rb);
142                 switch( do_read(sockfd, rb) )
143                 {
144                         case CONN_ERROR:
145                                 LM_ERR("failed to read from socket\n");
146                                 return AAA_CONN_CLOSED;
147                         case CONN_CLOSED:
148                                 LM_ERR("failed to read from socket\n");
149                                 return AAA_CONN_CLOSED;
150                 }
151                 
152                 /* obtain the structure corresponding to the message */
153                 msg = AAATranslateMessage(rb->buf, rb->buf_len, 0);     
154                 if(!msg)
155                 {
156                         LM_ERR("message structure not obtained\n");     
157                         return AAA_ERROR;
158                 }
159                 avp = AAAFindMatchingAVP(msg, NULL, AVP_SIP_MSGID,
160                                                                 vendorID, AAA_FORWARD_SEARCH);
161                 if(!avp)
162                 {
163                         LM_ERR("AVP_SIP_MSGID not found\n");
164                         return AAA_ERROR;
165                 }
166                 m_id = *((unsigned int*)(avp->data.s));
167                 LM_DBG("######## m_id=%d\n", m_id);
168                 if(m_id!=waited_id)
169                 {
170                         number_of_tries ++;
171                         LM_NOTICE("old message received\n");
172                         continue;
173                 }
174                 goto next;
175         }
176
177         LM_ERR("too many old messages received\n");
178         return AAA_TIMEOUT;
179 next:
180
181         /* Finally die correct answer */
182         avp = AAAFindMatchingAVP(msg, NULL, AVP_Service_Type,
183                                                         vendorID, AAA_FORWARD_SEARCH);
184         if(!avp)
185         {
186                 LM_ERR("AVP_Service_Type not found\n");
187                 return AAA_ERROR;
188         }
189         serviceType = avp->data.s[0];
190
191         result_code = ntohl(*((unsigned long int*)(msg->res_code->data.s)));
192         switch(result_code)
193         {
194                 case AAA_SUCCESS:                                       /* 2001 */
195                         return ACC_SUCCESS;
196                 default:                                                        /* error */
197                         return ACC_FAILURE;
198         }
199 }
200
201 void reset_read_buffer(rd_buf_t *rb)
202 {
203         rb->first_4bytes        = 0;
204         rb->buf_len                     = 0;
205         if(rb->buf)
206                 pkg_free(rb->buf);
207         rb->buf                         = 0;
208 }
209
210 /*! \brief read from a socket, an AAA message buffer */
211 int do_read( int socket, rd_buf_t *p)
212 {
213         unsigned char  *ptr;
214         unsigned int   wanted_len, len;
215         int n;
216
217         if (p->buf==0)
218         {
219                 wanted_len = sizeof(p->first_4bytes) - p->buf_len;
220                 ptr = ((unsigned char*)&(p->first_4bytes)) + p->buf_len;
221         }
222         else
223         {
224                 wanted_len = p->first_4bytes - p->buf_len;
225                 ptr = p->buf + p->buf_len;
226         }
227
228         while( (n=recv( socket, ptr, wanted_len, MSG_DONTWAIT ))>0 ) 
229         {
230 //              LM_DBG("(sock=%d)  -> n=%d (expected=%d)\n", p->sock,n,wanted_len);
231                 p->buf_len += n;
232                 if (n<wanted_len)
233                 {
234                         //LM_DBG("only %d bytes read from %d expected\n",n,wanted_len);
235                         wanted_len -= n;
236                         ptr += n;
237                 }
238                 else 
239                 {
240                         if (p->buf==0)
241                         {
242                                 /* I just finished reading the first 4 bytes from msg */
243                                 len = ntohl(p->first_4bytes)&0x00ffffff;
244                                 if (len<AAA_MSG_HDR_SIZE || len>MAX_AAA_MSG_SIZE)
245                                 {
246                                         LM_ERR("(sock=%d): invalid message "
247                                                 "length read %u (%x)\n", socket, len, p->first_4bytes);
248                                         goto error;
249                                 }
250                                 //LM_DBG("message length = %d(%x)\n",len,len);
251                                 if ( (p->buf=pkg_malloc(len))==0  )
252                                 {
253                                         LM_ERR("no more pkg memory\n");
254                                         goto error;
255                                 }
256                                 *((unsigned int*)p->buf) = p->first_4bytes;
257                                 p->buf_len = sizeof(p->first_4bytes);
258                                 p->first_4bytes = len;
259                                 /* update the reading position and len */
260                                 ptr = p->buf + p->buf_len;
261                                 wanted_len = p->first_4bytes - p->buf_len;
262                         }
263                         else
264                         {
265                                 /* I finished reading the whole message */
266                                 LM_DBG("(sock=%d): whole message read (len=%d)!\n",
267                                         socket, p->first_4bytes);
268                                 return CONN_SUCCESS;
269                         }
270                 }
271         }
272
273         if (n==0)
274         {
275                 LM_INFO("(sock=%d): FIN received\n", socket);
276                 return CONN_CLOSED;
277         }
278         if ( n==-1 && errno!=EINTR && errno!=EAGAIN )
279         {
280                 LM_ERR("(on sock=%d): n=%d , errno=%d (%s)\n",
281                         socket, n, errno, strerror(errno));
282                 goto error;
283         }
284 error:
285         return CONN_ERROR;
286 }
287
288
/*! \brief Tear down a connection opened with init_mytcp().
 *
 * Shuts down both directions of the connection and then releases the
 * descriptor. shutdown() alone does not free the file descriptor, so
 * close() must follow it.
 *
 * \param sfd  socket descriptor; negative values are ignored
 */
void close_tcp_connection(int sfd)
{
	if (sfd < 0)
		return;
	shutdown(sfd, SHUT_RDWR);
	close(sfd);
}
293
294 /*! \brief 
295  * Extract URI depending on the request from To or From header 
296  */
297 int get_uri(struct sip_msg* m, str** uri)
298 {
299         if ((REQ_LINE(m).method.len == 8) && 
300                                         (memcmp(REQ_LINE(m).method.s, "REGISTER", 8) == 0)) 
301         {/* REGISTER */
302                 if (!m->to && ((parse_headers(m, HDR_TO_F, 0) == -1) || !m->to )) 
303                 {
304                         LM_ERR("the To header field was not found or malformed\n");
305                         return -1;
306                 }
307                 *uri = &(get_to(m)->uri);
308         } 
309         else 
310         {
311                 if (parse_from_header(m)<0)
312                 {
313                         LM_ERR("failed to parse headers\n");
314                         return -2;
315                 }
316                 *uri = &(get_from(m)->uri);
317         }
318         return 0;
319 }
320
321
322 #endif