Sirikata
|
00001 /* Sirikata 00002 * SSTImpl.hpp 00003 * 00004 * Copyright (c) 2009, Tahir Azim. 00005 * All rights reserved. 00006 * 00007 * Redistribution and use in source and binary forms, with or without 00008 * modification, are permitted provided that the following conditions are 00009 * met: 00010 * * Redistributions of source code must retain the above copyright 00011 * notice, this list of conditions and the following disclaimer. 00012 * * Redistributions in binary form must reproduce the above copyright 00013 * notice, this list of conditions and the following disclaimer in 00014 * the documentation and/or other materials provided with the 00015 * distribution. 00016 * * Neither the name of Sirikata nor the names of its contributors may 00017 * be used to endorse or promote products derived from this software 00018 * without specific prior written permission. 00019 * 00020 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS 00021 * IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED 00022 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 00023 * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 00024 * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 00025 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 00026 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00027 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00028 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00029 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 */ 00032 00033 00034 #ifndef SST_IMPL_HPP 00035 #define SST_IMPL_HPP 00036 00037 #include <sirikata/core/util/Platform.hpp> 00038 00039 #include <sirikata/core/service/Service.hpp> 00040 #include <sirikata/core/util/Timer.hpp> 00041 #include <sirikata/core/service/Context.hpp> 00042 00043 #include <sirikata/core/network/Message.hpp> 00044 #include <sirikata/core/network/ObjectMessage.hpp> 00045 #include "Protocol_SSTHeader.pbj.hpp" 00046 00047 #include <boost/lexical_cast.hpp> 00048 #include <boost/asio.hpp> //htons, ntohs 00049 00050 #include <sirikata/core/options/CommonOptions.hpp> 00051 00052 #define SST_LOG(lvl,msg) SILOG(sst,lvl,msg); 00053 00054 namespace Sirikata { 00055 namespace SST { 00056 00057 template <typename EndObjectType> 00058 class EndPoint { 00059 public: 00060 EndObjectType endPoint; 00061 ObjectMessagePort port; 00062 00063 EndPoint() { 00064 } 00065 00066 EndPoint(EndObjectType endPoint, ObjectMessagePort port) { 00067 this->endPoint = endPoint; 00068 this->port = port; 00069 } 00070 00071 bool operator< (const EndPoint &ep) const{ 00072 if (endPoint != ep.endPoint) { 00073 return endPoint < ep.endPoint; 00074 } 00075 00076 return this->port < ep.port ; 00077 } 00078 00079 bool operator==(const EndPoint& ep) const { 00080 return ( 00081 this->port == ep.port && 00082 this->endPoint == ep.endPoint); 00083 } 00084 std::size_t hash() const { 00085 size_t seed = 0; 00086 boost::hash_combine(seed, typename EndObjectType::Hasher()(endPoint)); 00087 boost::hash_combine(seed, port); 00088 return seed; 00089 } 00090 00091 class Hasher{ 00092 public: 00093 size_t operator() (const EndPoint& ep) const { 00094 return ep.hash(); 00095 } 00096 }; 00097 00098 std::string toString() const { 00099 return endPoint.toString() + boost::lexical_cast<std::string>(port); 00100 } 00101 }; 00102 00103 class Mutex { 00104 public: 00105 00106 Mutex() { 00107 00108 } 00109 00110 Mutex(const Mutex& mutex) { } 00111 00112 boost::mutex& getMutex() { 00113 return mMutex; 00114 } 00115 00116 private: 00117 boost::mutex mMutex; 00118 00119 }; 00120 00121 template <class EndPointType> 00122 class Connection; 00123 00124 template <class EndPointType> 00125 class Stream; 00126 00127 template <typename EndPointType> 00128 class BaseDatagramLayer; 00129 00130 template <typename EndPointType> 00131 class ConnectionManager; 00132 00133 template <typename EndPointType> 00134 class CallbackTypes { 00135 public: 00136 typedef std::tr1::function< void(int, std::tr1::shared_ptr< Connection<EndPointType> > ) > ConnectionReturnCallbackFunction; 00137 typedef std::tr1::function< void(int, std::tr1::shared_ptr< Stream<EndPointType> >) > StreamReturnCallbackFunction; 00138 00139 typedef std::tr1::function< void (int, void*) > DatagramSendDoneCallback; 00140 typedef std::tr1::function<void (uint8*, int) > ReadDatagramCallback; 00141 typedef std::tr1::function<void (uint8*, int) > ReadCallback; 00142 }; 00143 00144 typedef UUID USID; 00145 00146 typedef uint32 LSID; 00147 00148 template <class EndPointType> 00149 class ConnectionVariables { 00150 public: 00151 00152 typedef std::tr1::shared_ptr<BaseDatagramLayer<EndPointType> > BaseDatagramLayerPtr; 00153 typedef CallbackTypes<EndPointType> CBTypes; 00154 typedef typename CBTypes::ConnectionReturnCallbackFunction ConnectionReturnCallbackFunction; 00155 typedef typename CBTypes::StreamReturnCallbackFunction StreamReturnCallbackFunction; 00156 00157 /* Returns 0 if no channel is available. Otherwise returns the lowest 00158 available channel. */ 00159 uint32 getAvailableChannel(EndPointType& endPointType) { 00160 BaseDatagramLayerPtr datagramLayer = getDatagramLayer(endPointType); 00161 assert (datagramLayer != BaseDatagramLayerPtr()); 00162 00163 return datagramLayer->getUnusedPort(endPointType); 00164 } 00165 00166 void releaseChannel(EndPointType& ept, uint32 channel) { 00167 00168 BaseDatagramLayerPtr datagramLayer = getDatagramLayer(ept); 00169 if (datagramLayer != BaseDatagramLayerPtr()) { 00170 00171 EndPoint<EndPointType> ep(ept, channel); 00172 00173 datagramLayer->unlisten(ep); 00174 } 00175 } 00176 00177 BaseDatagramLayerPtr getDatagramLayer(EndPointType& endPoint) 00178 { 00179 if (sDatagramLayerMap.find(endPoint) != sDatagramLayerMap.end()) { 00180 return sDatagramLayerMap[endPoint]; 00181 } 00182 00183 return BaseDatagramLayerPtr(); 00184 } 00185 00186 void addDatagramLayer(EndPointType& endPoint, BaseDatagramLayerPtr datagramLayer) 00187 { 00188 sDatagramLayerMap[endPoint] = datagramLayer; 00189 } 00190 00191 void removeDatagramLayer(EndPointType& endPoint, bool warn = false) 00192 { 00193 typename std::tr1::unordered_map<EndPointType, BaseDatagramLayerPtr, typename EndPointType::Hasher >::iterator wherei = sDatagramLayerMap.find(endPoint); 00194 if (wherei != sDatagramLayerMap.end()) { 00195 sDatagramLayerMap.erase(wherei); 00196 } else if (warn) { 00197 SILOG(sst,error,"FATAL: Invalidating BaseDatagramLayer that's invalid"); 00198 } 00199 } 00200 00201 private: 00202 std::tr1::unordered_map<EndPointType, BaseDatagramLayerPtr, typename EndPointType::Hasher > sDatagramLayerMap; 00203 00204 public: 00205 typedef std::tr1::unordered_map<EndPoint<EndPointType>, StreamReturnCallbackFunction, typename EndPoint<EndPointType>::Hasher> StreamReturnCallbackMap; 00206 StreamReturnCallbackMap mStreamReturnCallbackMap; 00207 00208 typedef std::tr1::unordered_map<EndPoint<EndPointType>, std::tr1::shared_ptr<Connection<EndPointType> >, typename EndPoint<EndPointType>::Hasher > ConnectionMap; 00209 ConnectionMap sConnectionMap; 00210 00211 typedef std::tr1::unordered_map<EndPoint<EndPointType>, ConnectionReturnCallbackFunction, typename EndPoint<EndPointType>::Hasher> ConnectionReturnCallbackMap; 00212 ConnectionReturnCallbackMap sConnectionReturnCallbackMap; 00213 00214 StreamReturnCallbackMap sListeningConnectionsCallbackMap; 00215 Mutex sStaticMembersLock; 00216 00217 }; 00218 00219 // This is just a template definition. The real implementation of BaseDatagramLayer 00220 // lies in libcore/include/sirikata/core/odp/SST.hpp and 00221 // libcore/include/sirikata/core/ohdp/SST.hpp. 00222 template <typename EndPointType> 00223 class SIRIKATA_EXPORT BaseDatagramLayer 00224 { 00225 // This class connects SST to the underlying datagram protocol. This isn't 00226 // an implementation -- the implementation will vary significantly for each 00227 // underlying datagram protocol -- but it does specify the interface. We 00228 // keep all types private in this version so it is obvious when you are 00229 // trying to incorrectly use this implementation instead of a real one. 00230 private: 00231 typedef std::tr1::shared_ptr<BaseDatagramLayer<EndPointType> > Ptr; 00232 typedef Ptr BaseDatagramLayerPtr; 00233 00234 typedef std::tr1::function<void(void*, int)> DataCallback; 00235 00244 static BaseDatagramLayerPtr createDatagramLayer( 00245 ConnectionVariables<EndPointType>* sstConnVars, 00246 EndPointType endPoint, 00247 const Context* ctx, 00248 void* extra) 00249 { 00250 return BaseDatagramLayerPtr(); 00251 } 00252 00254 static BaseDatagramLayerPtr getDatagramLayer(ConnectionVariables<EndPointType>* sstConnVars, 00255 EndPointType endPoint) 00256 { 00257 return BaseDatagramLayerPtr(); 00258 } 00259 00261 const Context* context() { 00262 return NULL; 00263 } 00264 00266 uint32 getUnusedPort(const EndPointType& ep) { 00267 return 0; 00268 } 00269 00270 00274 static void stopListening(ConnectionVariables<EndPointType>* sstConnVars, EndPoint<EndPointType>& listeningEndPoint) { 00275 } 00276 00280 void listenOn(EndPoint<EndPointType>& listeningEndPoint, DataCallback cb) { 00281 } 00282 00286 void listenOn(EndPoint<EndPointType>& listeningEndPoint) { 00287 } 00288 00293 void send(EndPoint<EndPointType>* src, EndPoint<EndPointType>* dest, void* data, int len) { 00294 } 00295 00299 void unlisten(EndPoint<EndPointType>& ep) { 00300 } 00301 00306 void invalidate() { 00307 } 00308 }; 00309 00310 #define SST_IMPL_SUCCESS 0 00311 #define SST_IMPL_FAILURE -1 00312 00313 class ChannelSegment { 00314 public: 00315 00316 uint8* mBuffer; 00317 uint16 mBufferLength; 00318 uint64 mChannelSequenceNumber; 00319 uint64 mAckSequenceNumber; 00320 00321 Time mTransmitTime; 00322 Time mAckTime; 00323 00324 ChannelSegment( const void* data, int len, uint64 channelSeqNum, uint64 ackSequenceNum) : 00325 mBufferLength(len), 00326 mChannelSequenceNumber(channelSeqNum), 00327 mAckSequenceNumber(ackSequenceNum), 00328 mTransmitTime(Time::null()), mAckTime(Time::null()) 00329 { 00330 mBuffer = new uint8[len]; 00331 memcpy( mBuffer, (const uint8*) data, len); 00332 } 00333 00334 ~ChannelSegment() { 00335 delete [] mBuffer; 00336 } 00337 00338 void setAckTime(Time& ackTime) { 00339 mAckTime = ackTime; 00340 } 00341 00342 }; 00343 00344 template <class EndPointType> 00345 class SIRIKATA_EXPORT Connection { 00346 public: 00347 typedef std::tr1::shared_ptr<Connection> Ptr; 00348 typedef Ptr ConnectionPtr; 00349 00350 private: 00351 typedef BaseDatagramLayer<EndPointType> BaseDatagramLayerType; 00352 typedef std::tr1::shared_ptr<BaseDatagramLayerType> BaseDatagramLayerPtr; 00353 00354 typedef CallbackTypes<EndPointType> CBTypes; 00355 typedef typename CBTypes::ConnectionReturnCallbackFunction ConnectionReturnCallbackFunction; 00356 typedef typename CBTypes::StreamReturnCallbackFunction StreamReturnCallbackFunction; 00357 typedef typename CBTypes::DatagramSendDoneCallback DatagramSendDoneCallback; 00358 typedef typename CBTypes::ReadDatagramCallback ReadDatagramCallback; 00359 00360 friend class Stream<EndPointType>; 00361 friend class ConnectionManager<EndPointType>; 00362 friend class BaseDatagramLayer<EndPointType>; 00363 00364 typedef std::tr1::unordered_map<EndPoint<EndPointType>, std::tr1::shared_ptr<Connection>, typename EndPoint<EndPointType>::Hasher > ConnectionMap; 00365 typedef std::tr1::unordered_map<EndPoint<EndPointType>, ConnectionReturnCallbackFunction, typename EndPoint<EndPointType>::Hasher> ConnectionReturnCallbackMap; 00366 typedef std::tr1::unordered_map<EndPoint<EndPointType>, StreamReturnCallbackFunction, typename EndPoint<EndPointType>::Hasher> StreamReturnCallbackMap; 00367 00368 EndPoint<EndPointType> mLocalEndPoint; 00369 EndPoint<EndPointType> mRemoteEndPoint; 00370 00371 ConnectionVariables<EndPointType>* mSSTConnVars; 00372 BaseDatagramLayerPtr mDatagramLayer; 00373 00374 int mState; 00375 uint32 mRemoteChannelID; 00376 uint32 mLocalChannelID; 00377 00378 uint64 mTransmitSequenceNumber; 00379 uint64 mLastReceivedSequenceNumber; //the last transmit sequence number received from the other side 00380 00381 typedef std::map<LSID, std::tr1::shared_ptr< Stream<EndPointType> > > LSIDStreamMap; 00382 std::map<LSID, std::tr1::shared_ptr< Stream<EndPointType> > > mOutgoingSubstreamMap; 00383 std::map<LSID, std::tr1::shared_ptr< Stream<EndPointType> > > mIncomingSubstreamMap; 00384 00385 std::map<uint32, StreamReturnCallbackFunction> mListeningStreamsCallbackMap; 00386 std::map<uint32, std::vector<ReadDatagramCallback> > mReadDatagramCallbacks; 00387 typedef std::vector<std::string> PartialPayloadList; 00388 typedef std::map<LSID, PartialPayloadList> PartialPayloadMap; 00389 PartialPayloadMap mPartialReadDatagrams; 00390 00391 uint32 mNumStreams; 00392 00393 std::deque< std::tr1::shared_ptr<ChannelSegment> > mQueuedSegments; 00394 std::deque< std::tr1::shared_ptr<ChannelSegment> > mOutstandingSegments; 00395 boost::mutex mOutstandingSegmentsMutex; 00396 00397 uint16 mCwnd; 00398 int64 mRTOMicroseconds; // RTO in microseconds 00399 bool mFirstRTO; 00400 00401 boost::mutex mQueueMutex; 00402 00403 uint16 MAX_DATAGRAM_SIZE; 00404 uint16 MAX_PAYLOAD_SIZE; 00405 uint32 MAX_QUEUED_SEGMENTS; 00406 float CC_ALPHA; 00407 Time mLastTransmitTime; 00408 00409 std::tr1::weak_ptr<Connection<EndPointType> > mWeakThis; 00410 00411 uint16 mNumInitialRetransmissionAttempts; 00412 00413 google::protobuf::LogSilencer logSilencer; 00414 00415 bool mInSendingMode; 00416 00417 private: 00418 00419 Connection(ConnectionVariables<EndPointType>* sstConnVars, 00420 EndPoint<EndPointType> localEndPoint, 00421 EndPoint<EndPointType> remoteEndPoint) 00422 : mLocalEndPoint(localEndPoint), mRemoteEndPoint(remoteEndPoint), 00423 mSSTConnVars(sstConnVars), 00424 mState(CONNECTION_DISCONNECTED), 00425 mRemoteChannelID(0), mLocalChannelID(1), mTransmitSequenceNumber(1), 00426 mLastReceivedSequenceNumber(1), 00427 mNumStreams(0), mCwnd(1), mRTOMicroseconds(2000000), 00428 mFirstRTO(true), MAX_DATAGRAM_SIZE(1000), MAX_PAYLOAD_SIZE(1300), 00429 MAX_QUEUED_SEGMENTS(3000), 00430 CC_ALPHA(0.8), mLastTransmitTime(Time::null()), 00431 mNumInitialRetransmissionAttempts(0), 00432 mInSendingMode(true) 00433 { 00434 mDatagramLayer = sstConnVars->getDatagramLayer(localEndPoint.endPoint); 00435 00436 mDatagramLayer->listenOn( 00437 localEndPoint, 00438 std::tr1::bind( 00439 &Connection::receiveMessage, this, 00440 std::tr1::placeholders::_1, 00441 std::tr1::placeholders::_2 00442 ) 00443 ); 00444 00445 } 00446 00447 void checkIfAlive(std::tr1::shared_ptr<Connection<EndPointType> > conn) { 00448 if (mOutgoingSubstreamMap.size() == 0 && mIncomingSubstreamMap.size() == 0) { 00449 close(true); 00450 return; 00451 } 00452 00453 getContext()->mainStrand->post(Duration::seconds(300), 00454 std::tr1::bind(&Connection<EndPointType>::checkIfAlive, this, conn), 00455 "Connection<EndPointType>::checkIfAlive" 00456 ); 00457 } 00458 00459 void sendSSTChannelPacket(Sirikata::Protocol::SST::SSTChannelHeader& sstMsg) { 00460 if (mState == CONNECTION_DISCONNECTED) return; 00461 00462 std::string buffer = serializePBJMessage(sstMsg); 00463 mDatagramLayer->send(&mLocalEndPoint, &mRemoteEndPoint, (void*) buffer.data(), 00464 buffer.size()); 00465 } 00466 00467 const Context* getContext() { 00468 return mDatagramLayer->context(); 00469 } 00470 00471 void serviceConnectionNoReturn(std::tr1::shared_ptr<Connection<EndPointType> > conn) { 00472 serviceConnection(conn); 00473 } 00474 00475 bool serviceConnection(std::tr1::shared_ptr<Connection<EndPointType> > conn) { 00476 const Time curTime = Timer::now(); 00477 00478 boost::mutex::scoped_lock lock(mOutstandingSegmentsMutex); 00479 00480 00481 if (mState == CONNECTION_PENDING_CONNECT) { 00482 mOutstandingSegments.clear(); 00483 } 00484 00485 // should start from ssthresh, the slow start lower threshold, but starting 00486 // from 1 for now. Still need to implement slow start. 00487 if (mState == CONNECTION_DISCONNECTED) { 00488 std::tr1::shared_ptr<Connection<EndPointType> > thus (mWeakThis.lock()); 00489 if (thus) { 00490 cleanup(thus); 00491 }else { 00492 SILOG(sst,error,"FATAL: disconnected lost weak pointer for Connection<EndPointType> too early to call cleanup on it"); 00493 } 00494 return false; 00495 } 00496 else if (mState == CONNECTION_PENDING_DISCONNECT) { 00497 boost::mutex::scoped_lock lock(mQueueMutex); 00498 00499 if (mQueuedSegments.empty()) { 00500 mState = CONNECTION_DISCONNECTED; 00501 std::tr1::shared_ptr<Connection<EndPointType> > thus (mWeakThis.lock()); 00502 if (thus) { 00503 cleanup(thus); 00504 }else { 00505 SILOG(sst,error,"FATAL: pending disconnection lost weak pointer for Connection<EndPointType> too early to call cleanup on it"); 00506 } 00507 return false; 00508 } 00509 } 00510 00511 if (mInSendingMode) { 00512 boost::mutex::scoped_lock lock(mQueueMutex); 00513 00514 for (int i = 0; (!mQueuedSegments.empty()) && mOutstandingSegments.size() <= mCwnd; i++) { 00515 std::tr1::shared_ptr<ChannelSegment> segment = mQueuedSegments.front(); 00516 00517 Sirikata::Protocol::SST::SSTChannelHeader sstMsg; 00518 sstMsg.set_channel_id( mRemoteChannelID ); 00519 sstMsg.set_transmit_sequence_number(segment->mChannelSequenceNumber); 00520 sstMsg.set_ack_count(1); 00521 sstMsg.set_ack_sequence_number(segment->mAckSequenceNumber); 00522 00523 sstMsg.set_payload(segment->mBuffer, segment->mBufferLength); 00524 00525 /*printf("%s sending packet from data sending loop to %s \n", 00526 mLocalEndPoint.endPoint.toString().c_str() 00527 , mRemoteEndPoint.endPoint.toString().c_str());*/ 00528 00529 00530 sendSSTChannelPacket(sstMsg); 00531 00532 if (mState == CONNECTION_PENDING_CONNECT) { 00533 mNumInitialRetransmissionAttempts++; 00534 } 00535 00536 segment->mTransmitTime = curTime; 00537 mOutstandingSegments.push_back(segment); 00538 00539 mLastTransmitTime = curTime; 00540 00541 if (mState != CONNECTION_PENDING_CONNECT || mNumInitialRetransmissionAttempts > 5) { 00542 mInSendingMode = false; 00543 mQueuedSegments.pop_front(); 00544 } 00545 } 00546 00547 if (!mInSendingMode || mState == CONNECTION_PENDING_CONNECT) { 00548 getContext()->mainStrand->post(Duration::microseconds(mRTOMicroseconds*pow(2.0,mNumInitialRetransmissionAttempts)), 00549 std::tr1::bind(&Connection<EndPointType>::serviceConnectionNoReturn, this, mWeakThis.lock()), 00550 "Connection<EndPointType>::serviceConnectionNoReturn" 00551 ); 00552 } 00553 } 00554 else { 00555 if (mState == CONNECTION_PENDING_CONNECT) { 00556 std::tr1::shared_ptr<Connection<EndPointType> > thus (mWeakThis.lock()); 00557 if (thus) { 00558 cleanup(thus); 00559 }else { 00560 SILOG(sst,error,"FATAL: pending connection lost weak pointer for Connection<EndPointType> too early to call cleanup on it"); 00561 } 00562 00563 return false; //the connection was unable to contact the other endpoint. 00564 } 00565 00566 if (mOutstandingSegments.size() > 0) { 00567 mCwnd /= 2; 00568 00569 if (mCwnd < 1) { 00570 mCwnd = 1; 00571 } 00572 00573 mOutstandingSegments.clear(); 00574 } 00575 00576 mInSendingMode = true; 00577 00578 getContext()->mainStrand->post(Duration::microseconds(1), 00579 std::tr1::bind(&Connection<EndPointType>::serviceConnectionNoReturn, this, mWeakThis.lock()), 00580 "Connection<EndPointType>::serviceConnectionNoReturn" 00581 ); 00582 } 00583 00584 return true; 00585 } 00586 00587 enum ConnectionStates { 00588 CONNECTION_DISCONNECTED = 1, // no network connectivity for this connection. 00589 // It has either never been connected or has 00590 // been fully disconnected. 00591 CONNECTION_PENDING_CONNECT = 2, // this connection is in the process of setting 00592 // up a connection. The connection setup will be 00593 // complete (or fail with an error) when the 00594 // application-specified callback is invoked. 00595 CONNECTION_PENDING_RECEIVE_CONNECT = 3,// connection received an initial 00596 // channel negotiation request, but the 00597 // negotiation has not completed yet. 00598 00599 CONNECTION_CONNECTED=4, // The connection is connected to a remote end 00600 // point. 00601 CONNECTION_PENDING_DISCONNECT=5, // The connection is in the process of 00602 // disconnecting from the remote end point. 00603 }; 00604 00605 00606 /* Create a connection for the application to a remote 00607 endpoint. The EndPoint argument specifies the location of the remote 00608 endpoint. It is templatized to enable it to refer to either IP 00609 addresses and ports, or object identifiers. The 00610 ConnectionReturnCallbackFunction returns a reference-counted, shared- 00611 pointer of the Connection that was created. The constructor may or 00612 may not actually synchronize with the remote endpoint. Instead the 00613 synchronization may be done when the first stream is created. 00614 00615 @EndPoint A templatized argument specifying the remote end-point to 00616 which this connection is connected. 00617 00618 @ConnectionReturnCallbackFunction A callback function which will be 00619 called once the connection is created and will provide a 00620 reference-counted, shared-pointer to the connection. 00621 ConnectionReturnCallbackFunction should have the signature 00622 void (std::tr1::shared_ptr<Connection>). If the std::tr1::shared_ptr argument 00623 is NULL, the connection setup failed. 00624 00625 @return false if it's not possible to create this connection, e.g. if another connection 00626 is already using the same local endpoint; true otherwise. 00627 */ 00628 00629 static bool createConnection(ConnectionVariables<EndPointType>* sstConnVars, 00630 EndPoint <EndPointType> localEndPoint, 00631 EndPoint <EndPointType> remoteEndPoint, 00632 ConnectionReturnCallbackFunction cb, 00633 StreamReturnCallbackFunction scb) 00634 00635 { 00636 boost::mutex::scoped_lock lock(sstConnVars->sStaticMembersLock.getMutex()); 00637 00638 ConnectionMap& connectionMap = sstConnVars->sConnectionMap; 00639 if (connectionMap.find(localEndPoint) != connectionMap.end()) { 00640 SST_LOG(warn, "sConnectionMap.find failed for " << localEndPoint.endPoint.toString() << "\n"); 00641 00642 return false; 00643 } 00644 00645 uint32 availableChannel = sstConnVars->getAvailableChannel(localEndPoint.endPoint); 00646 00647 if (availableChannel == 0) 00648 return false; 00649 00650 std::tr1::shared_ptr<Connection> conn = std::tr1::shared_ptr<Connection> ( 00651 new Connection(sstConnVars, localEndPoint, remoteEndPoint)); 00652 00653 connectionMap[localEndPoint] = conn; 00654 sstConnVars->sConnectionReturnCallbackMap[localEndPoint] = cb; 00655 00656 lock.unlock(); 00657 00658 conn->setWeakThis(conn); 00659 conn->setState(CONNECTION_PENDING_CONNECT); 00660 00661 uint32 payload[1]; 00662 payload[0] = htonl(availableChannel); 00663 00664 conn->setLocalChannelID(availableChannel); 00665 conn->sendData(payload, sizeof(payload), false); 00666 00667 return true; 00668 } 00669 00670 static bool listen(ConnectionVariables<EndPointType>* sstConnVars, StreamReturnCallbackFunction cb, EndPoint<EndPointType> listeningEndPoint) { 00671 sstConnVars->getDatagramLayer(listeningEndPoint.endPoint)->listenOn(listeningEndPoint); 00672 00673 boost::mutex::scoped_lock lock(sstConnVars->sStaticMembersLock.getMutex()); 00674 00675 StreamReturnCallbackMap& listeningConnectionsCallbackMap = sstConnVars->sListeningConnectionsCallbackMap; 00676 00677 if (listeningConnectionsCallbackMap.find(listeningEndPoint) != listeningConnectionsCallbackMap.end()){ 00678 return false; 00679 } 00680 00681 listeningConnectionsCallbackMap[listeningEndPoint] = cb; 00682 00683 return true; 00684 } 00685 00686 static bool unlisten(ConnectionVariables<EndPointType>* sstConnVars, EndPoint<EndPointType> listeningEndPoint) { 00687 BaseDatagramLayer<EndPointType>::stopListening(sstConnVars, listeningEndPoint); 00688 00689 boost::mutex::scoped_lock lock(sstConnVars->sStaticMembersLock.getMutex()); 00690 00691 sstConnVars->sListeningConnectionsCallbackMap.erase(listeningEndPoint); 00692 00693 return true; 00694 } 00695 00696 void listenStream(uint32 port, StreamReturnCallbackFunction scb) { 00697 mListeningStreamsCallbackMap[port] = scb; 00698 } 00699 00700 void unlistenStream(uint32 port) { 00701 mListeningStreamsCallbackMap.erase(port); 00702 } 00703 00704 /* Creates a stream on top of this connection. The function also queues 00705 up any initial data that needs to be sent on the stream. The function 00706 does not return a stream immediately since stream creation might 00707 take some time and yet fail in the end. So the function returns without 00708 synchronizing with the remote host. Instead the callback function 00709 provides a reference-counted, shared-pointer to the stream. 00710 If this connection hasn't synchronized with the remote endpoint yet, 00711 this function will also take care of doing that. 00712 00713 @data A pointer to the initial data buffer that needs to be sent on 00714 this stream. Having this pointer removes the need for the 00715 application to enqueue data until the stream is actually 00716 created. 00717 @port The length of the data buffer. 00718 @StreamReturnCallbackFunction A callback function which will be 00719 called once the stream is created and 00720 the initial data queued up (or actually 00721 sent?). The function will provide a 00722 reference counted, shared pointer to the 00723 connection. StreamReturnCallbackFunction 00724 should have the signature void (int,std::tr1::shared_ptr<Stream>). 00725 00726 @return the number of bytes queued from the initial data buffer, or -1 if there was an error. 00727 */ 00728 virtual int stream(StreamReturnCallbackFunction cb, void* initial_data, int length, 00729 uint32 local_port, uint32 remote_port) 00730 { 00731 return stream(cb, initial_data, length, local_port, remote_port, 0); 00732 } 00733 00734 virtual int stream(StreamReturnCallbackFunction cb, void* initial_data, int length, 00735 uint32 local_port, uint32 remote_port, LSID parentLSID) 00736 { 00737 USID usid = createNewUSID(); 00738 LSID lsid = ++mNumStreams; 00739 00740 std::tr1::shared_ptr<Stream<EndPointType> > stream = 00741 std::tr1::shared_ptr<Stream<EndPointType> > 00742 ( new Stream<EndPointType>(parentLSID, mWeakThis, local_port, remote_port, usid, lsid, cb, mSSTConnVars) ); 00743 stream->mWeakThis = stream; 00744 int numBytesBuffered = stream->init(initial_data, length, false, 0); 00745 00746 mOutgoingSubstreamMap[lsid]=stream; 00747 00748 return numBytesBuffered; 00749 } 00750 00751 uint64 sendData(const void* data, uint32 length, bool isAck) { 00752 boost::mutex::scoped_lock lock(mQueueMutex); 00753 00754 assert(length <= MAX_PAYLOAD_SIZE); 00755 00756 Sirikata::Protocol::SST::SSTStreamHeader* stream_msg = 00757 new Sirikata::Protocol::SST::SSTStreamHeader(); 00758 00759 std::string str = std::string( (char*)data, length); 00760 00761 bool parsed = parsePBJMessage(stream_msg, str); 00762 00763 uint64 transmitSequenceNumber = mTransmitSequenceNumber; 00764 00765 if ( isAck ) { 00766 Sirikata::Protocol::SST::SSTChannelHeader sstMsg; 00767 sstMsg.set_channel_id( mRemoteChannelID ); 00768 sstMsg.set_transmit_sequence_number(mTransmitSequenceNumber); 00769 sstMsg.set_ack_count(1); 00770 sstMsg.set_ack_sequence_number(mLastReceivedSequenceNumber); 00771 00772 sstMsg.set_payload(data, length); 00773 00774 sendSSTChannelPacket(sstMsg); 00775 } 00776 else { 00777 if (mQueuedSegments.size() < MAX_QUEUED_SEGMENTS) { 00778 mQueuedSegments.push_back( std::tr1::shared_ptr<ChannelSegment>( 00779 new ChannelSegment(data, length, mTransmitSequenceNumber, mLastReceivedSequenceNumber) ) ); 00780 00781 if (mInSendingMode) { 00782 getContext()->mainStrand->post(Duration::milliseconds(1.0), 00783 std::tr1::bind(&Connection::serviceConnectionNoReturn, this, mWeakThis.lock()), 00784 "Connection::serviceConnectionNoReturn" 00785 ); 00786 } 00787 } 00788 } 00789 00790 mTransmitSequenceNumber++; 00791 00792 delete stream_msg; 00793 00794 return transmitSequenceNumber; 00795 } 00796 00797 void setState(int state) { 00798 mState = state; 00799 } 00800 00801 uint8 getState() { 00802 return mState; 00803 } 00804 00805 void setLocalChannelID(uint32 channelID) { 00806 this->mLocalChannelID = channelID; 00807 } 00808 00809 void setRemoteChannelID(uint32 channelID) { 00810 this->mRemoteChannelID = channelID; 00811 } 00812 00813 void setWeakThis( std::tr1::shared_ptr<Connection> conn) { 00814 mWeakThis = conn; 00815 00816 getContext()->mainStrand->post(Duration::seconds(300), 00817 std::tr1::bind(&Connection<EndPointType>::checkIfAlive, this, conn), 00818 "Connection<EndPointType>::checkIfAlive" 00819 ); 00820 } 00821 00822 USID createNewUSID() { 00823 uint8 raw_uuid[UUID::static_size]; 00824 for(uint32 ui = 0; ui < UUID::static_size; ui++) 00825 raw_uuid[ui] = (uint8)rand() % 256; 00826 UUID id(raw_uuid, UUID::static_size); 00827 return id; 00828 } 00829 00830 void markAcknowledgedPacket(uint64 receivedAckNum) { 00831 boost::mutex::scoped_lock lock(mOutstandingSegmentsMutex); 00832 00833 for (std::deque< std::tr1::shared_ptr<ChannelSegment> >::iterator it = mOutstandingSegments.begin(); 00834 it != mOutstandingSegments.end(); it++) 00835 { 00836 std::tr1::shared_ptr<ChannelSegment> segment = *it; 00837 00838 if (!segment) { 00839 mOutstandingSegments.erase(it); 00840 it = mOutstandingSegments.begin(); 00841 continue; 00842 } 00843 00844 if (segment->mChannelSequenceNumber == receivedAckNum) { 00845 segment->mAckTime = Timer::now(); 00846 00847 if (mFirstRTO ) { 00848 mRTOMicroseconds = ((segment->mAckTime - segment->mTransmitTime).toMicroseconds()) ; 00849 mFirstRTO = false; 00850 } 00851 else { 00852 mRTOMicroseconds = CC_ALPHA * mRTOMicroseconds + 00853 (1.0-CC_ALPHA) * (segment->mAckTime - segment->mTransmitTime).toMicroseconds(); 00854 } 00855 00856 mInSendingMode = true; 00857 00858 std::tr1::shared_ptr<Connection<EndPointType> > conn = mWeakThis.lock(); 00859 if (conn) { 00860 getContext()->mainStrand->post( 00861 std::tr1::bind(&Connection<EndPointType>::serviceConnectionNoReturn, this, conn), 00862 "Connection<EndPointType>::serviceConnectionNoReturn" 00863 ); 00864 } 00865 00866 if (rand() % mCwnd == 0) { 00867 mCwnd += 1; 00868 } 00869 00870 mOutstandingSegments.erase(it); 00871 break; 00872 } 00873 } 00874 } 00875 00876 void parsePacket(Sirikata::Protocol::SST::SSTChannelHeader* received_channel_msg ) 00877 { 00878 Sirikata::Protocol::SST::SSTStreamHeader* received_stream_msg = 00879 new Sirikata::Protocol::SST::SSTStreamHeader(); 00880 bool parsed = parsePBJMessage(received_stream_msg, received_channel_msg->payload()); 00881 00882 if (received_stream_msg->type() == received_stream_msg->INIT) { 00883 handleInitPacket(received_stream_msg); 00884 } 00885 else if (received_stream_msg->type() == received_stream_msg->REPLY) { 00886 handleReplyPacket(received_stream_msg); 00887 } 00888 else if (received_stream_msg->type() == received_stream_msg->DATA) { 00889 handleDataPacket(received_stream_msg); 00890 } 00891 else if (received_stream_msg->type() == received_stream_msg->ACK) { 00892 handleAckPacket(received_channel_msg, received_stream_msg); 00893 } 00894 else if (received_stream_msg->type() == received_stream_msg->DATAGRAM) { 00895 handleDatagram(received_stream_msg); 00896 } 00897 00898 delete received_stream_msg ; 00899 } 00900 00901 void handleInitPacket(Sirikata::Protocol::SST::SSTStreamHeader* received_stream_msg) { 00902 LSID incomingLsid = received_stream_msg->lsid(); 00903 00904 if (mIncomingSubstreamMap.find(incomingLsid) == mIncomingSubstreamMap.end()) { 00905 if (mListeningStreamsCallbackMap.find(received_stream_msg->dest_port()) != 00906 mListeningStreamsCallbackMap.end()) 00907 { 00908 //create a new stream 00909 USID usid = createNewUSID(); 00910 LSID newLSID = ++mNumStreams; 00911 00912 std::tr1::shared_ptr<Stream<EndPointType> > stream = 00913 std::tr1::shared_ptr<Stream<EndPointType> > 00914 (new Stream<EndPointType> (received_stream_msg->psid(), mWeakThis, 00915 received_stream_msg->dest_port(), 00916 received_stream_msg->src_port(), 00917 usid, newLSID, 00918 NULL, mSSTConnVars)); 00919 stream->mWeakThis = stream; 00920 stream->init(NULL, 0, true, incomingLsid); 00921 00922 mOutgoingSubstreamMap[newLSID] = stream; 00923 mIncomingSubstreamMap[incomingLsid] = stream; 00924 00925 mListeningStreamsCallbackMap[received_stream_msg->dest_port()](0, stream); 00926 00927 stream->receiveData(received_stream_msg, received_stream_msg->payload().data(), 00928 received_stream_msg->bsn(), 00929 received_stream_msg->payload().size() ); 00930 } 00931 else { 00932 SST_LOG(warn, mLocalEndPoint.endPoint.toString() << " not listening to streams at: " << received_stream_msg->dest_port() << "\n"); 00933 } 00934 } 00935 else { 00936 mIncomingSubstreamMap[incomingLsid]->sendReplyPacket(NULL, 0, incomingLsid); 00937 } 00938 } 00939 00940 void handleReplyPacket(Sirikata::Protocol::SST::SSTStreamHeader* received_stream_msg) { 00941 LSID incomingLsid = received_stream_msg->lsid(); 00942 00943 if (mIncomingSubstreamMap.find(incomingLsid) == mIncomingSubstreamMap.end()) { 00944 LSID initiatingLSID = received_stream_msg->rsid(); 00945 00946 if (mOutgoingSubstreamMap.find(initiatingLSID) != mOutgoingSubstreamMap.end()) { 00947 std::tr1::shared_ptr< Stream<EndPointType> > stream = mOutgoingSubstreamMap[initiatingLSID]; 00948 mIncomingSubstreamMap[incomingLsid] = stream; 00949 stream->initRemoteLSID(incomingLsid); 00950 00951 if (stream->mStreamReturnCallback != NULL){ 00952 stream->mStreamReturnCallback(SST_IMPL_SUCCESS, stream); 00953 stream->mStreamReturnCallback = NULL; 00954 stream->receiveData(received_stream_msg, received_stream_msg->payload().data(), 00955 received_stream_msg->bsn(), 00956 received_stream_msg->payload().size() ); 00957 } 00958 } 00959 else { 00960 SST_LOG(detailed, "Received reply packet for unknown stream: " << initiatingLSID <<"\n"); 00961 } 00962 } 00963 } 00964 00965 void handleDataPacket(Sirikata::Protocol::SST::SSTStreamHeader* received_stream_msg) { 00966 LSID incomingLsid = received_stream_msg->lsid(); 00967 00968 if (mIncomingSubstreamMap.find(incomingLsid) != mIncomingSubstreamMap.end()) { 00969 std::tr1::shared_ptr< Stream<EndPointType> > stream_ptr = 00970 mIncomingSubstreamMap[incomingLsid]; 00971 stream_ptr->receiveData( received_stream_msg, 00972 received_stream_msg->payload().data(), 00973 received_stream_msg->bsn(), 00974 received_stream_msg->payload().size() 00975 ); 00976 } 00977 } 00978 00979 void handleAckPacket(Sirikata::Protocol::SST::SSTChannelHeader* received_channel_msg, 00980 Sirikata::Protocol::SST::SSTStreamHeader* received_stream_msg) 00981 { 00982 //printf("ACK received : offset = %d\n", (int)received_channel_msg->ack_sequence_number() ); 00983 LSID incomingLsid = received_stream_msg->lsid(); 00984 00985 if (mIncomingSubstreamMap.find(incomingLsid) != mIncomingSubstreamMap.end()) { 00986 std::tr1::shared_ptr< Stream<EndPointType> > stream_ptr = 00987 mIncomingSubstreamMap[incomingLsid]; 00988 stream_ptr->receiveData( received_stream_msg, 00989 received_stream_msg->payload().data(), 00990 received_channel_msg->ack_sequence_number(), 00991 received_stream_msg->payload().size() 00992 ); 00993 } 00994 } 00995 00996 void handleDatagram(Sirikata::Protocol::SST::SSTStreamHeader* received_stream_msg) { 00997 uint8 msg_flags = received_stream_msg->flags(); 00998 00999 if (msg_flags & Sirikata::Protocol::SST::SSTStreamHeader::CONTINUES) { 01000 // More data is coming, just store the current data 01001 mPartialReadDatagrams[received_stream_msg->lsid()].push_back( received_stream_msg->payload() ); 01002 } 01003 else { 01004 // Extract dispatch information 01005 uint32 dest_port = received_stream_msg->dest_port(); 01006 std::vector<ReadDatagramCallback> datagramCallbacks; 01007 if (mReadDatagramCallbacks.find(dest_port) != mReadDatagramCallbacks.end()) { 01008 datagramCallbacks = mReadDatagramCallbacks[dest_port]; 01009 } 01010 01011 // The datagram is all here, just deliver 01012 PartialPayloadMap::iterator it = mPartialReadDatagrams.find(received_stream_msg->lsid()); 01013 if (it != mPartialReadDatagrams.end()) { 01014 // Had previous partial packets 01015 // FIXME this should be more efficient 01016 std::string full_payload; 01017 for(PartialPayloadList::iterator pp_it = it->second.begin(); pp_it != it->second.end(); pp_it++) 01018 full_payload = full_payload + (*pp_it); 01019 full_payload = full_payload + received_stream_msg->payload(); 01020 mPartialReadDatagrams.erase(it); 01021 uint8* payload = (uint8*) full_payload.data(); 01022 uint32 payload_size = full_payload.size(); 01023 for (uint32 i=0 ; i < datagramCallbacks.size(); i++) { 01024 datagramCallbacks[i](payload, payload_size);; 01025 } 01026 } 01027 else { 01028 // Only this part, no need to aggregate into single buffer 01029 uint8* payload = (uint8*) received_stream_msg->payload().data(); 01030 uint32 payload_size = received_stream_msg->payload().size(); 01031 for (uint32 i=0 ; i < datagramCallbacks.size(); i++) { 01032 datagramCallbacks[i](payload, payload_size); 01033 } 01034 } 01035 } 01036 01037 01038 // And ack 01039 boost::mutex::scoped_lock lock(mQueueMutex); 01040 01041 Sirikata::Protocol::SST::SSTChannelHeader sstMsg; 01042 sstMsg.set_channel_id( mRemoteChannelID ); 01043 sstMsg.set_transmit_sequence_number(mTransmitSequenceNumber); 01044 sstMsg.set_ack_count(1); 01045 sstMsg.set_ack_sequence_number(mLastReceivedSequenceNumber); 01046 01047 sendSSTChannelPacket(sstMsg); 01048 01049 mTransmitSequenceNumber++; 01050 } 01051 01052 void receiveMessage(void* recv_buff, int len) { 01053 uint8* data = (uint8*) recv_buff; 01054 std::string str = std::string((char*) data, len); 01055 01056 Sirikata::Protocol::SST::SSTChannelHeader* received_msg = 01057 new Sirikata::Protocol::SST::SSTChannelHeader(); 01058 bool parsed = parsePBJMessage(received_msg, str); 01059 01060 mLastReceivedSequenceNumber = received_msg->transmit_sequence_number(); 01061 01062 uint64 receivedAckNum = received_msg->ack_sequence_number(); 01063 01064 markAcknowledgedPacket(receivedAckNum); 01065 01066 if (mState == CONNECTION_PENDING_CONNECT) { 01067 mState = CONNECTION_CONNECTED; 01068 01069 EndPoint<EndPointType> originalListeningEndPoint(mRemoteEndPoint.endPoint, mRemoteEndPoint.port); 01070 01071 uint32* received_payload = (uint32*) received_msg->payload().data(); 01072 if (received_msg->payload().size()>=sizeof(uint32)*2) { 01073 setRemoteChannelID( ntohl(received_payload[0])); 01074 mRemoteEndPoint.port = ntohl(received_payload[1]); 01075 } 01076 01077 sendData( received_payload, 0, false ); 01078 01079 boost::mutex::scoped_lock lock(mSSTConnVars->sStaticMembersLock.getMutex()); 01080 01081 ConnectionReturnCallbackMap& connectionReturnCallbackMap = mSSTConnVars->sConnectionReturnCallbackMap; 01082 ConnectionMap& connectionMap = mSSTConnVars->sConnectionMap; 01083 01084 if (connectionReturnCallbackMap.find(mLocalEndPoint) != connectionReturnCallbackMap.end()) 01085 { 01086 if (connectionMap.find(mLocalEndPoint) != connectionMap.end()) { 01087 std::tr1::shared_ptr<Connection> conn = connectionMap[mLocalEndPoint]; 01088 01089 connectionReturnCallbackMap[mLocalEndPoint] (SST_IMPL_SUCCESS, conn); 01090 } 01091 connectionReturnCallbackMap.erase(mLocalEndPoint); 01092 } 01093 } 01094 else if (mState == CONNECTION_PENDING_RECEIVE_CONNECT) { 01095 mState = CONNECTION_CONNECTED; 01096 } 01097 else if (mState == CONNECTION_CONNECTED) { 01098 if (received_msg->payload().size() > 0) { 01099 parsePacket(received_msg); 01100 } 01101 } 01102 01103 delete received_msg; 01104 } 01105 01106 uint64 getRTOMicroseconds() { 01107 return mRTOMicroseconds; 01108 } 01109 01110 void eraseDisconnectedStream(Stream<EndPointType>* s) { 01111 mOutgoingSubstreamMap.erase(s->getLSID()); 01112 mIncomingSubstreamMap.erase(s->getRemoteLSID()); 01113 01114 if (mOutgoingSubstreamMap.size() == 0 && mIncomingSubstreamMap.size() == 0) { 01115 close(true); 01116 } 01117 } 01118 01119 01120 // This is the version of cleanup is used from all the normal methods in Connection 01121 static void cleanup(std::tr1::shared_ptr<Connection<EndPointType> > conn) { 01122 conn->mDatagramLayer->unlisten(conn->mLocalEndPoint); 01123 01124 int connState = conn->mState; 01125 01126 if (connState == CONNECTION_PENDING_CONNECT || connState == CONNECTION_DISCONNECTED) { 01127 //Deal with the connection not getting connected with the remote endpoint. 01128 //This is in contrast to the case where the connection got connected, but 01129 //the connection's root stream was unable to do so. 01130 01131 boost::mutex::scoped_lock lock(conn->mSSTConnVars->sStaticMembersLock.getMutex()); 01132 ConnectionReturnCallbackFunction cb = NULL; 01133 01134 ConnectionReturnCallbackMap& connectionReturnCallbackMap = conn->mSSTConnVars->sConnectionReturnCallbackMap; 01135 if (connectionReturnCallbackMap.find(conn->localEndPoint()) != connectionReturnCallbackMap.end()) { 01136 cb = connectionReturnCallbackMap[conn->localEndPoint()]; 01137 } 01138 01139 std::tr1::shared_ptr<Connection> failed_conn = conn; 01140 01141 connectionReturnCallbackMap.erase(conn->localEndPoint()); 01142 conn->mSSTConnVars->sConnectionMap.erase(conn->localEndPoint()); 01143 01144 lock.unlock(); 01145 01146 01147 if (connState == CONNECTION_PENDING_CONNECT && cb ) { 01148 cb(SST_IMPL_FAILURE, failed_conn); 01149 } 01150 01151 conn->mState = CONNECTION_DISCONNECTED; 01152 } 01153 } 01154 01155 // This version should only be called by the destructor! 01156 void finalCleanup() { 01157 boost::mutex::scoped_lock lock(mSSTConnVars->sStaticMembersLock.getMutex()); 01158 01159 mDatagramLayer->unlisten(mLocalEndPoint); 01160 01161 if (mState != CONNECTION_DISCONNECTED) { 01162 iClose(true); 01163 mState = CONNECTION_DISCONNECTED; 01164 } 01165 01166 mSSTConnVars->releaseChannel(mLocalEndPoint.endPoint, mLocalChannelID); 01167 } 01168 01169 static void closeConnections(ConnectionVariables<EndPointType>* sstConnVars) { 01170 // We have to be careful with this function. Because it is going to free 01171 // the connections, we have to make sure not to let them get freed where 01172 // the deleter will modify sConnectionMap while we're still modifying it. 01173 // 01174 // Our approach is to just pick out the first connection, make a copy of 01175 // its shared_ptr to make sure it doesn't get freed until we want it to, 01176 // remove it from sConnectionMap, and then get rid of the shared_ptr to 01177 // allow the connection to be freed. 01178 // 01179 // Note the careful locking. Connection::~Connection will acquire the 01180 // sStaticMembersLock, so to avoid deadlocking we grab the shared_ptr, 01181 // remove it from the list and then only allow the Connection to be 01182 // destroyed after we've unlocked. 01183 while(true) { 01184 ConnectionPtr saved; 01185 { 01186 boost::mutex::scoped_lock lock(sstConnVars->sStaticMembersLock.getMutex()); 01187 if (sstConnVars->sConnectionMap.empty()) break; 01188 ConnectionMap& connectionMap = sstConnVars->sConnectionMap; 01189 01190 saved = connectionMap.begin()->second; 01191 connectionMap.erase(connectionMap.begin()); 01192 } 01193 saved.reset(); 01194 } 01195 } 01196 01197 static void handleReceive(ConnectionVariables<EndPointType>* sstConnVars, 01198 EndPoint<EndPointType> remoteEndPoint, 01199 EndPoint<EndPointType> localEndPoint, void* recv_buffer, int len) 01200 { 01201 char* data = (char*) recv_buffer; 01202 std::string str = std::string(data, len); 01203 01204 Sirikata::Protocol::SST::SSTChannelHeader* received_msg = new Sirikata::Protocol::SST::SSTChannelHeader(); 01205 bool parsed = parsePBJMessage(received_msg, str); 01206 01207 uint8 channelID = received_msg->channel_id(); 01208 01209 boost::mutex::scoped_lock lock(sstConnVars->sStaticMembersLock.getMutex()); 01210 01211 ConnectionMap& connectionMap = sstConnVars->sConnectionMap; 01212 if (connectionMap.find(localEndPoint) != connectionMap.end()) { 01213 if (channelID == 0) { 01214 /*Someone's already connected at this port. Either don't reply or 01215 send back a request rejected message. */ 01216 01217 SST_LOG(info, "Someone's already connected at this port on object " << localEndPoint.endPoint.toString() << "\n"); 01218 return; 01219 } 01220 std::tr1::shared_ptr<Connection<EndPointType> > conn = connectionMap[localEndPoint]; 01221 01222 conn->receiveMessage(data, len); 01223 } 01224 else if (channelID == 0) { 01225 /* it's a new channel request negotiation protocol 01226 packet ; allocate a new channel.*/ 01227 01228 StreamReturnCallbackMap& listeningConnectionsCallbackMap = sstConnVars->sListeningConnectionsCallbackMap; 01229 if (listeningConnectionsCallbackMap.find(localEndPoint) != listeningConnectionsCallbackMap.end()) { 01230 uint32* received_payload = (uint32*) received_msg->payload().data(); 01231 01232 uint32 payload[2]; 01233 01234 uint32 availableChannel = sstConnVars->getAvailableChannel(localEndPoint.endPoint); 01235 payload[0] = htonl(availableChannel); 01236 uint32 availablePort = availableChannel; //availableChannel is picked from the same 16-bit 01237 //address space and has to be unique. So why not use 01238 //use it to identify the port as well... 01239 payload[1] = htonl(availablePort); 01240 01241 EndPoint<EndPointType> newLocalEndPoint(localEndPoint.endPoint, availablePort); 01242 std::tr1::shared_ptr<Connection> conn = 01243 std::tr1::shared_ptr<Connection>( 01244 new Connection(sstConnVars, newLocalEndPoint, remoteEndPoint)); 01245 01246 01247 conn->listenStream(newLocalEndPoint.port, listeningConnectionsCallbackMap[localEndPoint]); 01248 conn->setWeakThis(conn); 01249 connectionMap[newLocalEndPoint] = conn; 01250 01251 conn->setLocalChannelID(availableChannel); 01252 if (received_msg->payload().size()>=sizeof(uint32)) { 01253 conn->setRemoteChannelID(ntohl(received_payload[0])); 01254 } 01255 conn->setState(CONNECTION_PENDING_RECEIVE_CONNECT); 01256 01257 conn->sendData(payload, sizeof(payload), false); 01258 } 01259 else { 01260 SST_LOG(warn, "No one listening on this connection\n"); 01261 } 01262 } 01263 01264 delete received_msg; 01265 } 01266 01267 public: 01268 01269 virtual ~Connection() { 01270 // Make sure we've fully cleaned up 01271 finalCleanup(); 01272 } 01273 01274 01275 /* Sends the specified data buffer using best-effort datagrams on the 01276 underlying connection. This may be done using an ephemeral stream 01277 on top of the underlying connection or some other mechanism (e.g. 01278 datagram packets sent directly on the underlying connection). 01279 01280 @param data the buffer to send 01281 @param length the length of the buffer 01282 @param local_port the source port 01283 @param remote_port the destination port 01284 @param DatagramSendDoneCallback a callback of type 01285 void (int errCode, void*) 01286 which is called when queuing 01287 the datagram failed or succeeded. 01288 'errCode' contains SST_IMPL_SUCCESS or SST_IMPL_FAILURE 01289 while the 'void*' argument is a pointer 01290 to the buffer that was being sent. 01291 01292 @return false if there's an immediate failure while enqueuing the datagram; true, otherwise. 01293 */ 01294 virtual bool datagram( void* data, int length, uint32 local_port, uint32 remote_port, 01295 DatagramSendDoneCallback cb) { 01296 int currOffset = 0; 01297 01298 if (mState == CONNECTION_DISCONNECTED 01299 || mState == CONNECTION_PENDING_DISCONNECT) 01300 { 01301 if (cb != NULL) { 01302 cb(SST_IMPL_FAILURE, data); 01303 } 01304 return false; 01305 } 01306 01307 LSID lsid = ++mNumStreams; 01308 01309 while (currOffset < length) { 01310 // Because the header is variable size, we have to have this 01311 // somewhat annoying logic to ensure we come in under the 01312 // budget. We start out with an extra 28 bytes as buffer. 01313 // Hopefully this is usually enough, and is based on the 01314 // current required header fields, their sizes, and overhead 01315 // from protocol buffers encoding. In the worst case, we end 01316 // up being too large and have to iterate, working with less 01317 // data over time. 01318 int header_buffer = 28; 01319 while(true) { 01320 int buffLen; 01321 bool continues; 01322 if (length-currOffset > (MAX_PAYLOAD_SIZE-header_buffer)) { 01323 buffLen = MAX_PAYLOAD_SIZE-header_buffer; 01324 continues = true; 01325 } 01326 else { 01327 buffLen = length-currOffset; 01328 continues = false; 01329 } 01330 01331 01332 Sirikata::Protocol::SST::SSTStreamHeader sstMsg; 01333 sstMsg.set_lsid( lsid ); 01334 sstMsg.set_type(sstMsg.DATAGRAM); 01335 01336 uint8 flags = 0; 01337 if (continues) 01338 flags = flags | Sirikata::Protocol::SST::SSTStreamHeader::CONTINUES; 01339 sstMsg.set_flags(flags); 01340 01341 sstMsg.set_window( (unsigned char)10 ); 01342 sstMsg.set_src_port(local_port); 01343 sstMsg.set_dest_port(remote_port); 01344 01345 sstMsg.set_payload( ((uint8*)data)+currOffset, buffLen); 01346 01347 std::string buffer = serializePBJMessage(sstMsg); 01348 01349 // If we're not within the payload size, we need to 01350 // increase our buffer space and try again 01351 if (buffer.size() > MAX_PAYLOAD_SIZE) { 01352 header_buffer += 10; 01353 continue; 01354 } 01355 01356 sendData( buffer.data(), buffer.size(), false ); 01357 01358 currOffset += buffLen; 01359 // If we got to the send, we can break out of the loop 01360 break; 01361 } 01362 } 01363 01364 if (cb != NULL) { 01365 //invoke the callback function 01366 cb(SST_IMPL_SUCCESS, data); 01367 } 01368 01369 return true; 01370 } 01371 01372 /* 01373 Register a callback which will be called when there is a datagram 01374 available to be read. 01375 01376 @param port the local port on which to listen for datagrams. 01377 @param ReadDatagramCallback a function of type "void (uint8*, int)" 01378 which will be called when a datagram is available. The 01379 "uint8*" field will be filled up with the received datagram, 01380 while the 'int' field will contain its size. 01381 @return true if the callback was successfully registered. 01382 */ 01383 virtual bool registerReadDatagramCallback(uint32 port, ReadDatagramCallback cb) { 01384 if (mReadDatagramCallbacks.find(port) == mReadDatagramCallbacks.end()) { 01385 mReadDatagramCallbacks[port] = std::vector<ReadDatagramCallback>(); 01386 } 01387 01388 mReadDatagramCallbacks[port].push_back(cb); 01389 01390 return true; 01391 } 01392 01393 /* 01394 Register a callback which will be called when there is a new 01395 datagram available to be read. In other words, datagrams we have 01396 seen previously will not trigger this callback. 01397 01398 @param ReadDatagramCallback a function of type "void (uint8*, int)" 01399 which will be called when a datagram is available. The 01400 "uint8*" field will be filled up with the received datagram, 01401 while the 'int' field will contain its size. 01402 @return true if the callback was successfully registered. 01403 */ 01404 virtual bool registerReadOrderedDatagramCallback( ReadDatagramCallback cb ) { 01405 return true; 01406 } 01407 01408 /* Closes the connection. 01409 01410 @param force if true, the connection is closed forcibly and 01411 immediately. Otherwise, the connection is closed 01412 gracefully and all outstanding packets are sent and 01413 acknowledged. Note that even in the latter case, 01414 the function returns without synchronizing with the 01415 remote end point. 01416 */ 01417 virtual void close(bool force) { 01418 boost::mutex::scoped_lock lock(mSSTConnVars->sStaticMembersLock.getMutex()); 01419 iClose(force); 01420 } 01421 01422 /* Internal, non-locking implementation of close(). 01423 Lock mSSTConnVars->sStaticMembersLock before calling this function */ 01424 virtual void iClose(bool force) { 01425 /* (mState != CONNECTION_DISCONNECTED) implies close() wasnt called 01426 through the destructor. */ 01427 if (force && mState != CONNECTION_DISCONNECTED) { 01428 mSSTConnVars->sConnectionMap.erase(mLocalEndPoint); 01429 } 01430 01431 if (force) { 01432 mState = CONNECTION_DISCONNECTED; 01433 } 01434 else { 01435 mState = CONNECTION_PENDING_DISCONNECT; 01436 } 01437 } 01438 01439 01440 01441 /* 01442 Returns the local endpoint to which this connection is bound. 01443 01444 @return the local endpoint. 01445 */ 01446 virtual EndPoint <EndPointType> localEndPoint() { 01447 return mLocalEndPoint; 01448 } 01449 01450 /* 01451 Returns the remote endpoint to which this connection is connected. 01452 01453 @return the remote endpoint. 01454 */ 01455 virtual EndPoint <EndPointType> remoteEndPoint() { 01456 return mRemoteEndPoint; 01457 } 01458 01459 }; 01460 01461 01462 class StreamBuffer{ 01463 public: 01464 01465 uint8* mBuffer; 01466 uint32 mBufferLength; 01467 uint64 mOffset; 01468 01469 Time mTransmitTime; 01470 Time mAckTime; 01471 01472 StreamBuffer(const uint8* data, uint32 len, uint64 offset) : 01473 mTransmitTime(Time::null()), mAckTime(Time::null()) 01474 { 01475 mBuffer = new uint8[len+1]; 01476 01477 if (len > 0) { 01478 memcpy(mBuffer,data,len); 01479 } 01480 01481 mBufferLength = len; 01482 mOffset = offset; 01483 } 01484 01485 ~StreamBuffer() { 01486 delete []mBuffer; 01487 } 01488 }; 01489 01490 template <class EndPointType> 01491 class SIRIKATA_EXPORT Stream { 01492 public: 01493 typedef std::tr1::shared_ptr<Stream> Ptr; 01494 typedef Ptr StreamPtr; 01495 typedef Connection<EndPointType> ConnectionType; 01496 typedef EndPoint<EndPointType> EndpointType; 01497 01498 typedef CallbackTypes<EndPointType> CBTypes; 01499 typedef typename CBTypes::StreamReturnCallbackFunction StreamReturnCallbackFunction; 01500 typedef typename CBTypes::ReadCallback ReadCallback; 01501 01502 typedef std::tr1::unordered_map<EndPoint<EndPointType>, StreamReturnCallbackFunction, typename EndPoint<EndPointType>::Hasher> StreamReturnCallbackMap; 01503 01504 enum StreamStates { 01505 DISCONNECTED = 1, 01506 CONNECTED=2, 01507 PENDING_DISCONNECT=3, 01508 PENDING_CONNECT=4, 01509 NOT_FINISHED_CONSTRUCTING__CALL_INIT 01510 }; 01511 01512 01513 01514 virtual ~Stream() { 01515 close(true); 01516 01517 delete [] mInitialData; 01518 delete [] mReceiveBuffer; 01519 delete [] mReceiveBitmap; 01520 01521 mConnection.reset(); 01522 } 01523 01524 bool connected() { return mConnected; } 01525 01526 static bool connectStream(ConnectionVariables<EndPointType>* sstConnVars, 01527 EndPoint <EndPointType> localEndPoint, 01528 EndPoint <EndPointType> remoteEndPoint, 01529 StreamReturnCallbackFunction cb) 01530 { 01531 if (localEndPoint.port == 0) { 01532 typename BaseDatagramLayer<EndPointType>::Ptr bdl = sstConnVars->getDatagramLayer(localEndPoint.endPoint); 01533 if (!bdl) { 01534 SST_LOG(error,"Tried to connect stream without calling createDatagramLayer for the endpoint."); 01535 return false; 01536 } 01537 localEndPoint.port = bdl->getUnusedPort(localEndPoint.endPoint); 01538 } 01539 01540 StreamReturnCallbackMap& streamReturnCallbackMap = sstConnVars->mStreamReturnCallbackMap; 01541 if (streamReturnCallbackMap.find(localEndPoint) != streamReturnCallbackMap.end()) { 01542 return false; 01543 } 01544 01545 streamReturnCallbackMap[localEndPoint] = cb; 01546 01547 bool result = Connection<EndPointType>::createConnection(sstConnVars, 01548 localEndPoint, 01549 remoteEndPoint, 01550 connectionCreated, cb); 01551 return result; 01552 } 01553 01554 /* 01555 Start listening for top-level streams on the specified end-point. When 01556 a new top-level stream connects at the given endpoint, the specified 01557 callback function is invoked handing the object a top-level stream. 01558 @param cb the callback function invoked when a new stream is created 01559 @param listeningEndPoint the endpoint where SST will accept new incoming 01560 streams. 01561 @return false, if its not possible to listen to this endpoint (e.g. if listen 01562 has already been called on this endpoint); true otherwise. 01563 */ 01564 static bool listen(ConnectionVariables<EndPointType>* sstConnVars, StreamReturnCallbackFunction cb, EndPoint <EndPointType> listeningEndPoint) { 01565 return Connection<EndPointType>::listen(sstConnVars, cb, listeningEndPoint); 01566 } 01567 01568 static bool unlisten(ConnectionVariables<EndPointType>* sstConnVars, EndPoint <EndPointType> listeningEndPoint) { 01569 return Connection<EndPointType>::unlisten(sstConnVars, listeningEndPoint); 01570 } 01571 01572 /* 01573 Start listening for child streams on the specified port. A remote stream 01574 can only create child streams under this stream if this stream is listening 01575 on the port specified for the child stream. 01576 01577 @param scb the callback function invoked when a new stream is created 01578 @param port the endpoint where SST will accept new incoming 01579 streams. 01580 */ 01581 void listenSubstream(uint32 port, StreamReturnCallbackFunction scb) { 01582 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 01583 if (!conn) { 01584 scb(SST_IMPL_FAILURE, StreamPtr() ); 01585 return; 01586 } 01587 01588 conn->listenStream(port, scb); 01589 } 01590 01591 void unlistenSubstream(uint32 port) { 01592 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 01593 01594 if (conn) 01595 conn->unlistenStream(port); 01596 } 01597 01598 /* Writes data bytes to the stream. If not all bytes can be transmitted 01599 immediately, they are queued locally until ready to transmit. 01600 @param data the buffer containing the bytes to be written 01601 @param len the length of the buffer 01602 @return the number of bytes written or enqueued, or -1 if an error 01603 occurred 01604 */ 01605 virtual int write(const uint8* data, int len) { 01606 if (mState == DISCONNECTED || mState == PENDING_DISCONNECT) { 01607 return -1; 01608 } 01609 01610 boost::mutex::scoped_lock lock(mQueueMutex); 01611 int count = 0; 01612 01613 if (len <= MAX_PAYLOAD_SIZE) { 01614 if (mCurrentQueueLength+len > MAX_QUEUE_LENGTH) { 01615 return 0; 01616 } 01617 mQueuedBuffers.push_back( std::tr1::shared_ptr<StreamBuffer>(new StreamBuffer(data, len, mNumBytesSent)) ); 01618 mCurrentQueueLength += len; 01619 mNumBytesSent += len; 01620 01621 01622 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 01623 if (conn) 01624 getContext()->mainStrand->post(Duration::seconds(0.01), 01625 std::tr1::bind(&Stream<EndPointType>::serviceStreamNoReturn, this, mWeakThis.lock(), conn), 01626 "Stream<EndPointType>::serviceStreamNoReturn" 01627 ); 01628 01629 return len; 01630 } 01631 else { 01632 int currOffset = 0; 01633 while (currOffset < len) { 01634 int buffLen = (len-currOffset > MAX_PAYLOAD_SIZE) ? 01635 MAX_PAYLOAD_SIZE : 01636 (len-currOffset); 01637 01638 if (mCurrentQueueLength + buffLen > MAX_QUEUE_LENGTH) { 01639 break; 01640 } 01641 01642 mQueuedBuffers.push_back( std::tr1::shared_ptr<StreamBuffer>(new StreamBuffer(data+currOffset, buffLen, mNumBytesSent)) ); 01643 currOffset += buffLen; 01644 mCurrentQueueLength += buffLen; 01645 mNumBytesSent += buffLen; 01646 01647 count++; 01648 } 01649 01650 01651 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 01652 if (conn) 01653 getContext()->mainStrand->post(Duration::seconds(0.01), 01654 std::tr1::bind(&Stream<EndPointType>::serviceStreamNoReturn, this, mWeakThis.lock(), conn), 01655 "Stream<EndPointType>::serviceStreamNoReturn" 01656 ); 01657 01658 return currOffset; 01659 } 01660 01661 return -1; 01662 } 01663 01664 #if SIRIKATA_PLATFORM != SIRIKATA_PLATFORM_WINDOWS 01665 /* Gathers data from the buffers described in 'vec', 01666 which is taken to be 'count' structures long, and 01667 writes them to the stream. As each buffer is 01668 written, it moves on to the next. If not all bytes 01669 can be transmitted immediately, they are queued 01670 locally until ready to transmit. 01671 01672 The return value is a count of bytes written. 01673 01674 @param vec the array containing the iovec buffers to be written 01675 @param count the number of iovec buffers in the array 01676 @return the number of bytes written or enqueued, or -1 if an error 01677 occurred 01678 */ 01679 virtual int writev(const struct iovec* vec, int count) { 01680 int totalBytesWritten = 0; 01681 01682 for (int i=0; i < count; i++) { 01683 int numWritten = write( (const uint8*) vec[i].iov_base, vec[i].iov_len); 01684 01685 if (numWritten < 0) return -1; 01686 01687 totalBytesWritten += numWritten; 01688 01689 if (numWritten == 0) { 01690 return totalBytesWritten; 01691 } 01692 } 01693 01694 return totalBytesWritten; 01695 } 01696 #endif 01697 01698 /* 01699 Register a callback which will be called when there are bytes to be 01700 read from the stream. 01701 01702 @param ReadCallback a function of type "void (uint8*, int)" which will 01703 be called when data is available. The "uint8*" field will be filled 01704 up with the received data, while the 'int' field will contain 01705 the size of the data. 01706 @return true if the callback was successfully registered. 01707 */ 01708 virtual bool registerReadCallback( ReadCallback callback) { 01709 mReadCallback = callback; 01710 01711 boost::recursive_mutex::scoped_lock lock(mReceiveBufferMutex); 01712 sendToApp(0); 01713 01714 return true; 01715 } 01716 01717 /* Close this stream. If the 'force' parameter is 'false', 01718 all outstanding data is sent and acknowledged before the stream is closed. 01719 Otherwise, the stream is closed immediately and outstanding data may be lost. 01720 Note that in the former case, the function will still return immediately, changing 01721 the state of the connection PENDING_DISCONNECT without necessarily talking to the 01722 remote endpoint. 01723 @param force use false if the stream should be gracefully closed, true otherwise. 01724 @return true if the stream was successfully closed. 01725 01726 */ 01727 virtual bool close(bool force) { 01728 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 01729 if (force) { 01730 mConnected = false; 01731 mState = DISCONNECTED; 01732 01733 if (conn) { 01734 conn->eraseDisconnectedStream(this); 01735 } 01736 01737 return true; 01738 } 01739 else { 01740 mState = PENDING_DISCONNECT; 01741 if (conn) { 01742 getContext()->mainStrand->post( 01743 std::tr1::bind(&Stream<EndPointType>::serviceStreamNoReturn, this, mWeakThis.lock(), conn), 01744 "Stream<EndPointType>::serviceStreamNoReturn" 01745 ); 01746 } 01747 return true; 01748 } 01749 } 01750 01751 /* 01752 Sets the priority of this stream. 01753 As in the original SST interface, this implementation gives strict preference to 01754 streams with higher priority over streams with lower priority, but it divides 01755 available transmit bandwidth evenly among streams with the same priority level. 01756 All streams have a default priority level of zero. 01757 @param the new priority level of the stream. 01758 */ 01759 virtual void setPriority(int pri) { 01760 01761 } 01762 01763 /*Returns the stream's current priority level. 01764 @return the stream's current priority level 01765 */ 01766 virtual int priority() { 01767 return 0; 01768 } 01769 01770 /* Returns the top-level connection that created this stream. 01771 @return a pointer to the connection that created this stream. 01772 */ 01773 virtual std::tr1::weak_ptr<Connection<EndPointType> > connection() { 01774 return mConnection; 01775 } 01776 01777 /* Creates a child stream. The function also queues up 01778 any initial data that needs to be sent on the child stream. The function does not 01779 return a stream immediately since stream creation might take some time and 01780 yet fail in the end. So the function returns without synchronizing with the 01781 remote host. Instead the callback function provides a reference-counted, 01782 shared-pointer to the stream. If this connection hasn't synchronized with 01783 the remote endpoint yet, this function will also take care of doing that. 01784 01785 @param data A pointer to the initial data buffer that needs to be sent on this stream. 01786 Having this pointer removes the need for the application to enqueue data 01787 until the stream is actually created. 01788 @param port The length of the data buffer. 01789 @param local_port the local port to which the child stream will be bound. 01790 @param remote_port the remote port to which the child stream should connect. 01791 @param StreamReturnCallbackFunction A callback function which will be called once the 01792 stream is created and the initial data queued up 01793 (or actually sent?). The function will provide a 01794 reference counted, shared pointer to the connection. 01795 01796 @return the number of bytes actually buffered from the initial data buffer specified, or 01797 -1 if an error occurred. 01798 */ 01799 virtual int createChildStream(StreamReturnCallbackFunction cb, void* data, int length, 01800 uint32 local_port, uint32 remote_port) 01801 { 01802 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 01803 if (conn) { 01804 return conn->stream(cb, data, length, local_port, remote_port, mLSID); 01805 } 01806 01807 return -1; 01808 } 01809 01810 /* 01811 Returns the local endpoint to which this connection is bound. 01812 01813 @return the local endpoint. 01814 */ 01815 virtual EndPoint <EndPointType> localEndPoint() { 01816 return mLocalEndPoint; 01817 } 01818 01819 /* 01820 Returns the remote endpoint to which this connection is bound. 01821 01822 @return the remote endpoint. 01823 */ 01824 virtual EndPoint <EndPointType> remoteEndPoint() { 01825 return mRemoteEndPoint; 01826 } 01827 01828 virtual uint8 getState() { 01829 return mState; 01830 } 01831 01832 const Context* getContext() { 01833 if (mContext == NULL) { 01834 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 01835 assert(conn); 01836 01837 mContext = conn->getContext(); 01838 } 01839 01840 return mContext; 01841 } 01842 01843 private: 01844 Stream(LSID parentLSID, std::tr1::weak_ptr<Connection<EndPointType> > conn, 01845 uint32 local_port, uint32 remote_port, 01846 USID usid, LSID lsid, StreamReturnCallbackFunction cb, ConnectionVariables<EndPointType>* sstConnVars) 01847 : 01848 mState(NOT_FINISHED_CONSTRUCTING__CALL_INIT), 01849 mLocalPort(local_port), 01850 mRemotePort(remote_port), 01851 mParentLSID(parentLSID), 01852 mConnection(conn),mContext(NULL), 01853 mUSID(usid), 01854 mLSID(lsid), 01855 mRemoteLSID(-1), 01856 MAX_PAYLOAD_SIZE(1000), 01857 MAX_QUEUE_LENGTH(4000000), 01858 MAX_RECEIVE_WINDOW(GetOptionValue<uint32>(OPT_SST_DEFAULT_WINDOW_SIZE)), 01859 mFirstRTO(true), 01860 mStreamRTOMicroseconds(2000000), 01861 FL_ALPHA(0.8), 01862 mTransmitWindowSize(MAX_RECEIVE_WINDOW), 01863 mReceiveWindowSize(MAX_RECEIVE_WINDOW), 01864 mNumOutstandingBytes(0), 01865 mNextByteExpected(0), 01866 mLastContiguousByteReceived(-1), 01867 mLastSendTime(Time::null()), 01868 mLastReceiveTime(Time::null()), 01869 mStreamReturnCallback(cb), 01870 mConnected (false), 01871 MAX_INIT_RETRANSMISSIONS(5), 01872 mSSTConnVars(sstConnVars) 01873 { 01874 mInitialData = NULL; 01875 mInitialDataLength = 0; 01876 01877 mReceiveBuffer = NULL; 01878 mReceiveBitmap = NULL; 01879 01880 mQueuedBuffers.clear(); 01881 mCurrentQueueLength = 0; 01882 01883 std::tr1::shared_ptr<Connection<EndPointType> > locked_conn = mConnection.lock(); 01884 mRemoteEndPoint = EndPoint<EndPointType> (locked_conn->remoteEndPoint().endPoint, mRemotePort); 01885 mLocalEndPoint = EndPoint<EndPointType> (locked_conn->localEndPoint().endPoint, mLocalPort); 01886 01887 // Continues in init, when we have mWeakThis set 01888 } 01889 01890 int init(void* initial_data, uint32 length, bool remotelyInitiated, LSID remoteLSID) { 01891 mNumInitRetransmissions = 1; 01892 if (remotelyInitiated) { 01893 mRemoteLSID = remoteLSID; 01894 mConnected = true; 01895 mState = CONNECTED; 01896 } 01897 else { 01898 mConnected = false; 01899 mState = PENDING_CONNECT; 01900 } 01901 01902 mInitialDataLength = (length <= MAX_PAYLOAD_SIZE) ? length : MAX_PAYLOAD_SIZE; 01903 int numBytesBuffered = mInitialDataLength; 01904 01905 if (initial_data != NULL) { 01906 mInitialData = new uint8[mInitialDataLength]; 01907 01908 memcpy(mInitialData, initial_data, mInitialDataLength); 01909 } 01910 else { 01911 mInitialData = new uint8[1]; 01912 mInitialDataLength = 0; 01913 } 01914 01915 if (remotelyInitiated) { 01916 sendReplyPacket(mInitialData, mInitialDataLength, remoteLSID); 01917 } 01918 else { 01919 sendInitPacket(mInitialData, mInitialDataLength); 01920 } 01921 01922 mNumBytesSent = mInitialDataLength; 01923 01924 if (length > mInitialDataLength) { 01925 int writeval = write( ((uint8*)initial_data) + mInitialDataLength, length - mInitialDataLength); 01926 01927 if (writeval >= 0) { 01928 numBytesBuffered += writeval; 01929 } 01930 } 01931 01933 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 01934 if (conn) { 01935 getContext()->mainStrand->post(Duration::seconds(60), 01936 std::tr1::bind(&Stream<EndPointType>::sendKeepAlive, this, mWeakThis, conn), 01937 "Stream<EndPointType>::sendKeepAlive" 01938 ); 01939 } 01940 01941 return numBytesBuffered; 01942 } 01943 01944 uint8* receiveBuffer() { 01945 if (mReceiveBuffer == NULL) 01946 mReceiveBuffer = new uint8[MAX_RECEIVE_WINDOW]; 01947 return mReceiveBuffer; 01948 } 01949 01950 uint8* receiveBitmap() { 01951 if (mReceiveBitmap == NULL) { 01952 mReceiveBitmap = new uint8[MAX_RECEIVE_WINDOW]; 01953 memset(mReceiveBitmap, 0, MAX_RECEIVE_WINDOW); 01954 } 01955 return mReceiveBitmap; 01956 } 01957 01958 void initRemoteLSID(LSID remoteLSID) { 01959 mRemoteLSID = remoteLSID; 01960 } 01961 01962 void sendKeepAlive(std::tr1::weak_ptr<Stream<EndPointType> > wstrm, std::tr1::shared_ptr<Connection<EndPointType> > conn) { 01963 std::tr1::shared_ptr<Stream<EndPointType> > strm = wstrm.lock(); 01964 if (!strm) return; 01965 01966 if (mState == DISCONNECTED || mState == PENDING_DISCONNECT) { 01967 close(true); 01968 return; 01969 } 01970 01971 uint8 buf[1]; 01972 01973 write(buf, 0); 01974 01975 getContext()->mainStrand->post(Duration::seconds(60), 01976 std::tr1::bind(&Stream<EndPointType>::sendKeepAlive, this, wstrm, conn), 01977 "Stream<EndPointType>::sendKeepAlive" 01978 ); 01979 } 01980 01981 static void connectionCreated( int errCode, std::tr1::shared_ptr<Connection<EndPointType> > c) { 01982 StreamReturnCallbackMap& streamReturnCallbackMap = c->mSSTConnVars->mStreamReturnCallbackMap; 01983 assert(streamReturnCallbackMap.find(c->localEndPoint()) != streamReturnCallbackMap.end()); 01984 01985 if (errCode != SST_IMPL_SUCCESS) { 01986 01987 StreamReturnCallbackFunction cb = streamReturnCallbackMap[c->localEndPoint()]; 01988 streamReturnCallbackMap.erase(c->localEndPoint()); 01989 01990 cb(SST_IMPL_FAILURE, StreamPtr() ); 01991 01992 return; 01993 } 01994 01995 c->stream(streamReturnCallbackMap[c->localEndPoint()], NULL , 0, 01996 c->localEndPoint().port, c->remoteEndPoint().port); 01997 01998 streamReturnCallbackMap.erase(c->localEndPoint()); 01999 } 02000 02001 void serviceStreamNoReturn(std::tr1::shared_ptr<Stream<EndPointType> > strm, std::tr1::shared_ptr<Connection<EndPointType> > conn) { 02002 serviceStream(strm, conn); 02003 } 02004 02005 /* Returns false only if this is the root stream of a connection and it was 02006 unable to connect. In that case, the connection for this stream needs to 02007 be closed and the 'false' return value is an indication of this for 02008 the underlying connection. */ 02009 02010 bool serviceStream(std::tr1::shared_ptr<Stream<EndPointType> > strm, std::tr1::shared_ptr<Connection<EndPointType> > conn) { 02011 assert(strm.get() == this); 02012 02013 const Time curTime = Timer::now(); 02014 02015 if ( (curTime - mLastReceiveTime).toSeconds() > 300 && mLastReceiveTime != Time::null()) 02016 { 02017 close(true); 02018 return true; 02019 } 02020 02021 if (mState != CONNECTED && mState != DISCONNECTED && mState != PENDING_DISCONNECT) { 02022 02023 if (!mConnected && mNumInitRetransmissions < MAX_INIT_RETRANSMISSIONS ) { 02024 02025 sendInitPacket(mInitialData, mInitialDataLength); 02026 02027 mLastSendTime = curTime; 02028 02029 mNumInitRetransmissions++; 02030 02031 return true; 02032 } 02033 02034 mInitialDataLength = 0; 02035 02036 if (!mConnected) { 02037 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 02038 assert(conn); 02039 02040 mSSTConnVars->mStreamReturnCallbackMap.erase(conn->localEndPoint()); 02041 02042 // If this is the root stream that failed to connect, close the 02043 // connection associated with it as well. 02044 if (mParentLSID == 0) { 02045 conn->close(true); 02046 02047 Connection<EndPointType>::cleanup(conn); 02048 } 02049 02050 //send back an error to the app by calling mStreamReturnCallback 02051 //with an error code. 02052 if (mStreamReturnCallback) { 02053 mStreamReturnCallback(SST_IMPL_FAILURE, StreamPtr()); 02054 mStreamReturnCallback = NULL; 02055 } 02056 02057 02058 conn->eraseDisconnectedStream(this); 02059 mState = DISCONNECTED; 02060 02061 return false; 02062 } 02063 else { 02064 mState = CONNECTED; 02065 // Schedule another servicing immediately in case any other operations 02066 // should occur, e.g. sending data which was added after the initial 02067 // connection request. 02068 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 02069 if (conn) 02070 getContext()->mainStrand->post( 02071 std::tr1::bind(&Stream<EndPointType>::serviceStreamNoReturn, this, mWeakThis.lock(), conn), 02072 "Stream<EndPointType>::serviceStreamNoReturn" 02073 ); 02074 } 02075 } 02076 else { 02077 if (mState != DISCONNECTED) { 02078 02079 //if the stream has been waiting for an ACK for > 2*mStreamRTOMicroseconds, 02080 //resend the unacked packets. We don't actually check if we 02081 //have anything to ack here, that happens in resendUnackedPackets. Also, 02082 //'resending' really just means sticking them back at the front of 02083 //mQueuedBuffers, so the code that follows and actually sends data will 02084 //ensure that we trigger a re-servicing sometime in the future. 02085 if ( mLastSendTime != Time::null() 02086 && (curTime - mLastSendTime).toMicroseconds() > 2*mStreamRTOMicroseconds) 02087 { 02088 resendUnackedPackets(); 02089 mLastSendTime = curTime; 02090 } 02091 02092 boost::mutex::scoped_lock lock(mQueueMutex); 02093 02094 if (mState == PENDING_DISCONNECT && 02095 mQueuedBuffers.empty() && 02096 mChannelToBufferMap.empty() ) 02097 { 02098 mState = DISCONNECTED; 02099 02100 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 02101 assert(conn); 02102 02103 conn->eraseDisconnectedStream(this); 02104 02105 return true; 02106 } 02107 02108 bool sentSomething = false; 02109 while ( !mQueuedBuffers.empty() ) { 02110 std::tr1::shared_ptr<StreamBuffer> buffer = mQueuedBuffers.front(); 02111 02112 if (mTransmitWindowSize < buffer->mBufferLength) { 02113 break; 02114 } 02115 02116 uint64 channelID = sendDataPacket(buffer->mBuffer, 02117 buffer->mBufferLength, 02118 buffer->mOffset 02119 ); 02120 buffer->mTransmitTime = curTime; 02121 sentSomething = true; 02122 02123 if ( mChannelToBufferMap.find(channelID) == mChannelToBufferMap.end() ) { 02124 mChannelToBufferMap[channelID] = buffer; 02125 mChannelToStreamOffsetMap[channelID] = buffer->mOffset; 02126 } 02127 02128 mQueuedBuffers.pop_front(); 02129 mCurrentQueueLength -= buffer->mBufferLength; 02130 mLastSendTime = curTime; 02131 02132 assert(buffer->mBufferLength <= mTransmitWindowSize); 02133 mTransmitWindowSize -= buffer->mBufferLength; 02134 mNumOutstandingBytes += buffer->mBufferLength; 02135 } 02136 02137 if (sentSomething) { 02138 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 02139 if (conn) 02140 getContext()->mainStrand->post(Duration::microseconds(2*mStreamRTOMicroseconds), 02141 std::tr1::bind(&Stream<EndPointType>::serviceStreamNoReturn, this, mWeakThis.lock(), conn), 02142 "Stream<EndPointType>::serviceStreamNoReturn" 02143 ); 02144 } 02145 } 02146 } 02147 02148 return true; 02149 } 02150 02151 inline void resendUnackedPackets() { 02152 boost::mutex::scoped_lock lock(mQueueMutex); 02153 02154 for(std::map<uint64,std::tr1::shared_ptr<StreamBuffer> >::const_reverse_iterator it=mChannelToBufferMap.rbegin(), 02155 it_end=mChannelToBufferMap.rend(); 02156 it != it_end; it++) 02157 { 02158 mQueuedBuffers.push_front(it->second); 02159 mCurrentQueueLength += it->second->mBufferLength; 02160 02161 /*printf("On %d, resending unacked packet at offset %d:%d\n", 02162 (int)mLSID, (int)it->first, (int)(it->second->mOffset));fflush(stdout);*/ 02163 02164 if (mTransmitWindowSize < it->second->mBufferLength){ 02165 assert( ((int) it->second->mBufferLength) > 0); 02166 mTransmitWindowSize = it->second->mBufferLength; 02167 } 02168 } 02169 02170 if (mChannelToBufferMap.empty() && !mQueuedBuffers.empty()) { 02171 std::tr1::shared_ptr<StreamBuffer> buffer = mQueuedBuffers.front(); 02172 02173 if (mTransmitWindowSize < buffer->mBufferLength) { 02174 mTransmitWindowSize = buffer->mBufferLength; 02175 } 02176 } 02177 02178 mNumOutstandingBytes = 0; 02179 02180 if (!mChannelToBufferMap.empty()) { 02181 if (mStreamRTOMicroseconds < 20000000) { 02182 mStreamRTOMicroseconds *= 2; 02183 } 02184 mChannelToBufferMap.clear(); 02185 } 02186 } 02187 02188 /* This function sends received data up to the application interface. 02189 mReceiveBufferMutex must be locked before calling this function. */ 02190 void sendToApp(uint32 skipLength) { 02191 // Special case: if we're not marking any data as skipped and we 02192 // haven't allocated the receive bitmap yet, then we're not 02193 // going to send anything anyway. Just ignore this call. 02194 if (mReceiveBitmap == NULL && skipLength == 0) 02195 return; 02196 02197 uint32 readyBufferSize = skipLength; 02198 uint8* recv_bmap = receiveBitmap(); 02199 for (uint32 i=skipLength; i < MAX_RECEIVE_WINDOW; i++) { 02200 if (recv_bmap[i] == 1) { 02201 readyBufferSize++; 02202 } 02203 else if (recv_bmap[i] == 0) { 02204 break; 02205 } 02206 } 02207 02208 //pass data up to the app from 0 to readyBufferSize; 02209 // 02210 if (mReadCallback != NULL && readyBufferSize > 0) { 02211 uint8* recv_buf = receiveBuffer(); 02212 mReadCallback(recv_buf, readyBufferSize); 02213 02214 //now move the window forward... 02215 mLastContiguousByteReceived = mLastContiguousByteReceived + readyBufferSize; 02216 mNextByteExpected = mLastContiguousByteReceived + 1; 02217 02218 uint8* recv_bmap = receiveBitmap(); 02219 memmove(recv_bmap, recv_bmap + readyBufferSize, MAX_RECEIVE_WINDOW - readyBufferSize); 02220 memset(recv_bmap + (MAX_RECEIVE_WINDOW - readyBufferSize), 0, readyBufferSize); 02221 02222 memmove(recv_buf, recv_buf + readyBufferSize, MAX_RECEIVE_WINDOW - readyBufferSize); 02223 02224 mReceiveWindowSize += readyBufferSize; 02225 } 02226 } 02227 02228 void receiveData( Sirikata::Protocol::SST::SSTStreamHeader* streamMsg, 02229 const void* buffer, uint64 offset, uint32 len ) 02230 { 02231 const Time curTime = Timer::now(); 02232 mLastReceiveTime = curTime; 02233 02234 if (streamMsg->type() == streamMsg->REPLY) { 02235 mConnected = true; 02236 } 02237 else if (streamMsg->type() == streamMsg->DATA || streamMsg->type() == streamMsg->INIT) { 02238 boost::recursive_mutex::scoped_lock lock(mReceiveBufferMutex); 02239 02240 int transmitWindowSize = pow(2.0, streamMsg->window()) - mNumOutstandingBytes; 02241 if (transmitWindowSize >= 0) { 02242 mTransmitWindowSize = transmitWindowSize; 02243 } 02244 else { 02245 mTransmitWindowSize = 0; 02246 } 02247 02248 02249 /*std::cout << "offset=" << offset << " , mLastContiguousByteReceived=" << mLastContiguousByteReceived 02250 << " , mNextByteExpected=" << mNextByteExpected <<"\n";*/ 02251 02252 int64 offsetInBuffer = offset - mLastContiguousByteReceived - 1; 02253 if ( len > 0 && (int64)(offset) == mNextByteExpected) { 02254 if (offsetInBuffer + len <= MAX_RECEIVE_WINDOW) { 02255 mReceiveWindowSize -= len; 02256 02257 memcpy(receiveBuffer()+offsetInBuffer, buffer, len); 02258 memset(receiveBitmap()+offsetInBuffer, 1, len); 02259 02260 sendToApp(len); 02261 02262 //send back an ack. 02263 sendAckPacket(); 02264 } 02265 else { 02266 //dont ack this packet.. its falling outside the receive window. 02267 sendToApp(0); 02268 } 02269 } 02270 else if (len > 0) { 02271 if ( (int64)(offset+len-1) <= (int64)mLastContiguousByteReceived) { 02272 //printf("Acking packet which we had already received previously\n"); 02273 sendAckPacket(); 02274 } 02275 else if (offsetInBuffer + len <= MAX_RECEIVE_WINDOW) { 02276 assert (offsetInBuffer + len > 0); 02277 02278 mReceiveWindowSize -= len; 02279 02280 memcpy(receiveBuffer()+offsetInBuffer, buffer, len); 02281 memset(receiveBitmap()+offsetInBuffer, 1, len); 02282 02283 sendAckPacket(); 02284 } 02285 else { 02286 //dont ack this packet.. its falling outside the receive window. 02287 sendToApp(0); 02288 } 02289 } 02290 else if (len == 0 && (int64)(offset) == mNextByteExpected) { 02291 // A zero length packet at the next expected offset. This is a keep 02292 // alive, which are just empty packets that we process to keep the 02293 // connection running. Send an ack so we don't end up with unacked 02294 // keep alive packets. 02295 sendAckPacket(); 02296 } 02297 } 02298 02299 //handle any ACKS that might be included in the message... 02300 boost::mutex::scoped_lock lock(mQueueMutex); 02301 02302 bool acked_msgs = false; 02303 if (mChannelToBufferMap.find(offset) != mChannelToBufferMap.end()) { 02304 uint64 dataOffset = mChannelToBufferMap[offset]->mOffset; 02305 mNumOutstandingBytes -= mChannelToBufferMap[offset]->mBufferLength; 02306 02307 mChannelToBufferMap[offset]->mAckTime = curTime; 02308 02309 updateRTO(mChannelToBufferMap[offset]->mTransmitTime, mChannelToBufferMap[offset]->mAckTime); 02310 02311 if ( (int) (pow(2.0, streamMsg->window()) - mNumOutstandingBytes) > 0 ) { 02312 assert( pow(2.0, streamMsg->window()) - mNumOutstandingBytes > 0); 02313 mTransmitWindowSize = pow(2.0, streamMsg->window()) - mNumOutstandingBytes; 02314 } 02315 else { 02316 mTransmitWindowSize = 0; 02317 } 02318 02319 //printf("REMOVED ack packet at offset %d\n", (int)mChannelToBufferMap[offset]->mOffset); 02320 02321 acked_msgs = true; 02322 mChannelToBufferMap.erase(offset); 02323 02324 std::vector <uint64> channelOffsets; 02325 for(std::map<uint64, std::tr1::shared_ptr<StreamBuffer> >::iterator it = mChannelToBufferMap.begin(); 02326 it != mChannelToBufferMap.end(); ++it) 02327 { 02328 if (it->second->mOffset == dataOffset) { 02329 channelOffsets.push_back(it->first); 02330 } 02331 } 02332 02333 for (uint32 i=0; i< channelOffsets.size(); i++) { 02334 mChannelToBufferMap.erase(channelOffsets[i]); 02335 } 02336 } 02337 else { 02338 // ACK received but not found in mChannelToBufferMap 02339 if (mChannelToStreamOffsetMap.find(offset) != mChannelToStreamOffsetMap.end()) { 02340 uint64 dataOffset = mChannelToStreamOffsetMap[offset]; 02341 acked_msgs = true; 02342 mChannelToStreamOffsetMap.erase(offset); 02343 02344 std::vector <uint64> channelOffsets; 02345 for(std::map<uint64, std::tr1::shared_ptr<StreamBuffer> >::iterator it = mChannelToBufferMap.begin(); 02346 it != mChannelToBufferMap.end(); ++it) 02347 { 02348 if (it->second->mOffset == dataOffset) { 02349 channelOffsets.push_back(it->first); 02350 } 02351 } 02352 02353 for (uint32 i=0; i< channelOffsets.size(); i++) { 02354 mChannelToBufferMap.erase(channelOffsets[i]); 02355 } 02356 } 02357 } 02358 02359 // If we acked messages, we've cleared space in the transmit 02360 // buffer (the receiver cleared something out of its receive 02361 // buffer). We can send more data, so schedule servicing if we 02362 // have anything queued. 02363 // TODO(ewencp) maybe only schedule something new when this 02364 // started as a full transmit buffer? Have to be careful about 02365 // this though since mTransmitWindowSize might not be 0 even if it 02366 // was 'full' since the next packet couldn't fit on. 02367 if (acked_msgs && !mQueuedBuffers.empty()) { 02368 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 02369 if (conn) { 02370 getContext()->mainStrand->post( 02371 std::tr1::bind(&Stream<EndPointType>::serviceStreamNoReturn, this, mWeakThis.lock(), conn), 02372 "Stream<EndPointType>::serviceStreamNoReturn" 02373 ); 02374 } 02375 } 02376 } 02377 02378 LSID getLSID() { 02379 return mLSID; 02380 } 02381 02382 LSID getRemoteLSID() { 02383 return mRemoteLSID; 02384 } 02385 02386 void updateRTO(Time sampleStartTime, Time sampleEndTime) { 02387 02388 02389 if (sampleStartTime > sampleEndTime ) { 02390 SST_LOG(insane, "Bad sample\n"); 02391 return; 02392 } 02393 02394 if (mFirstRTO) { 02395 mStreamRTOMicroseconds = (sampleEndTime - sampleStartTime).toMicroseconds() ; 02396 mFirstRTO = false; 02397 } 02398 else { 02399 02400 mStreamRTOMicroseconds = FL_ALPHA * mStreamRTOMicroseconds + 02401 (1.0-FL_ALPHA) * (sampleEndTime - sampleStartTime).toMicroseconds(); 02402 } 02403 02404 } 02405 02406 void sendInitPacket(void* data, uint32 len) { 02407 Sirikata::Protocol::SST::SSTStreamHeader sstMsg; 02408 sstMsg.set_lsid( mLSID ); 02409 sstMsg.set_type(sstMsg.INIT); 02410 sstMsg.set_flags(0); 02411 sstMsg.set_window( log((double)mReceiveWindowSize)/log(2.0) ); 02412 sstMsg.set_src_port(mLocalPort); 02413 sstMsg.set_dest_port(mRemotePort); 02414 02415 sstMsg.set_psid(mParentLSID); 02416 02417 sstMsg.set_bsn(0); 02418 02419 sstMsg.set_payload(data, len); 02420 02421 std::string buffer = serializePBJMessage(sstMsg); 02422 02423 02424 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 02425 02426 if (!conn) return; 02427 02428 conn->sendData( buffer.data(), buffer.size(), false ); 02429 02430 getContext()->mainStrand->post( 02431 Duration::microseconds(pow(2.0,mNumInitRetransmissions)*mStreamRTOMicroseconds), 02432 std::tr1::bind(&Stream<EndPointType>::serviceStreamNoReturn, this, mWeakThis.lock(), conn), 02433 "Stream<EndPointType>::serviceStreamNoReturn" 02434 ); 02435 02436 } 02437 02438 void sendAckPacket() { 02439 Sirikata::Protocol::SST::SSTStreamHeader sstMsg; 02440 sstMsg.set_lsid( mLSID ); 02441 sstMsg.set_type(sstMsg.ACK); 02442 sstMsg.set_flags(0); 02443 sstMsg.set_window( log((double)mReceiveWindowSize)/log(2.0) ); 02444 sstMsg.set_src_port(mLocalPort); 02445 sstMsg.set_dest_port(mRemotePort); 02446 std::string buffer = serializePBJMessage(sstMsg); 02447 02448 //printf("Sending Ack packet with window %d\n", (int)sstMsg.window()); 02449 02450 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 02451 assert(conn); 02452 conn->sendData( buffer.data(), buffer.size(), true); 02453 } 02454 02455 uint64 sendDataPacket(const void* data, uint32 len, uint64 offset) { 02456 Sirikata::Protocol::SST::SSTStreamHeader sstMsg; 02457 sstMsg.set_lsid( mLSID ); 02458 sstMsg.set_type(sstMsg.DATA); 02459 sstMsg.set_flags(0); 02460 sstMsg.set_window( log((double)mReceiveWindowSize)/log(2.0) ); 02461 sstMsg.set_src_port(mLocalPort); 02462 sstMsg.set_dest_port(mRemotePort); 02463 02464 sstMsg.set_bsn(offset); 02465 02466 sstMsg.set_payload(data, len); 02467 02468 std::string buffer = serializePBJMessage(sstMsg); 02469 02470 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 02471 assert(conn); 02472 return conn->sendData( buffer.data(), buffer.size(), false); 02473 } 02474 02475 void sendReplyPacket(void* data, uint32 len, LSID remoteLSID) { 02476 Sirikata::Protocol::SST::SSTStreamHeader sstMsg; 02477 sstMsg.set_lsid( mLSID ); 02478 sstMsg.set_type(sstMsg.REPLY); 02479 sstMsg.set_flags(0); 02480 sstMsg.set_window( log((double)mReceiveWindowSize)/log(2.0) ); 02481 sstMsg.set_src_port(mLocalPort); 02482 sstMsg.set_dest_port(mRemotePort); 02483 02484 sstMsg.set_rsid(remoteLSID); 02485 sstMsg.set_bsn(0); 02486 02487 sstMsg.set_payload(data, len); 02488 std::string buffer = serializePBJMessage(sstMsg); 02489 02490 std::tr1::shared_ptr<Connection<EndPointType> > conn = mConnection.lock(); 02491 assert(conn); 02492 conn->sendData( buffer.data(), buffer.size(), false); 02493 } 02494 02495 uint8 mState; 02496 02497 uint32 mLocalPort; 02498 uint32 mRemotePort; 02499 02500 uint64 mNumBytesSent; 02501 02502 LSID mParentLSID; 02503 02504 //weak_ptr to avoid circular dependency between Connection and Stream classes 02505 std::tr1::weak_ptr<Connection<EndPointType> > mConnection; 02506 const Context* mContext; 02507 02508 std::map<uint64, std::tr1::shared_ptr<StreamBuffer> > mChannelToBufferMap; 02509 std::map<uint64, uint32> mChannelToStreamOffsetMap; 02510 02511 std::deque< std::tr1::shared_ptr<StreamBuffer> > mQueuedBuffers; 02512 uint32 mCurrentQueueLength; 02513 02514 USID mUSID; 02515 LSID mLSID; 02516 LSID mRemoteLSID; 02517 02518 uint16 MAX_PAYLOAD_SIZE; 02519 uint32 MAX_QUEUE_LENGTH; 02520 uint32 MAX_RECEIVE_WINDOW; 02521 02522 boost::mutex mQueueMutex; 02523 02524 bool mFirstRTO; 02525 int64 mStreamRTOMicroseconds; 02526 float FL_ALPHA; 02527 02528 02529 uint32 mTransmitWindowSize; 02530 uint32 mReceiveWindowSize; 02531 uint32 mNumOutstandingBytes; 02532 02533 int64 mNextByteExpected; 02534 int64 mLastContiguousByteReceived; 02535 Time mLastSendTime; 02536 Time mLastReceiveTime; 02537 02538 uint8* mReceiveBuffer; 02539 uint8* mReceiveBitmap; 02540 boost::recursive_mutex mReceiveBufferMutex; 02541 02542 ReadCallback mReadCallback; 02543 StreamReturnCallbackFunction mStreamReturnCallback; 02544 02545 friend class Connection<EndPointType>; 02546 02547 /* Variables required for the initial connection */ 02548 bool mConnected; 02549 uint8* mInitialData; 02550 uint16 mInitialDataLength; 02551 uint8 mNumInitRetransmissions; 02552 uint8 MAX_INIT_RETRANSMISSIONS; 02553 02554 ConnectionVariables<EndPointType>* mSSTConnVars; 02555 02556 std::tr1::weak_ptr<Stream<EndPointType> > mWeakThis; 02557 02560 EndPoint <EndPointType> mLocalEndPoint; 02561 EndPoint <EndPointType> mRemoteEndPoint; 02562 02563 }; 02564 02565 02574 template <class EndPointType> 02575 class ConnectionManager : public Service { 02576 public: 02577 typedef std::tr1::shared_ptr<BaseDatagramLayer<EndPointType> > BaseDatagramLayerPtr; 02578 02579 typedef CallbackTypes<EndPointType> CBTypes; 02580 typedef typename CBTypes::StreamReturnCallbackFunction StreamReturnCallbackFunction; 02581 02582 virtual void start() { 02583 } 02584 02585 virtual void stop() { 02586 Connection<EndPointType>::closeConnections(&mSSTConnVars); 02587 } 02588 02589 ~ConnectionManager() { 02590 Connection<EndPointType>::closeConnections(&mSSTConnVars); 02591 } 02592 02593 bool connectStream(EndPoint <EndPointType> localEndPoint, 02594 EndPoint <EndPointType> remoteEndPoint, 02595 StreamReturnCallbackFunction cb) 02596 { 02597 return Stream<EndPointType>::connectStream(&mSSTConnVars, localEndPoint, remoteEndPoint, cb); 02598 } 02599 02600 // The BaseDatagramLayer is really where the interaction with the underlying 02601 // system happens, and different underlying protocols may require different 02602 // parameters. These need to be instantiated by the client code anyway (to 02603 // generate the interface), so we provide some templatized versions to allow 02604 // a variable number of arguments. 02605 template<typename A1> 02606 BaseDatagramLayerPtr createDatagramLayer(EndPointType endPoint, Context* ctx, A1 a1) { 02607 return BaseDatagramLayer<EndPointType>::createDatagramLayer(&mSSTConnVars, endPoint, ctx, a1); 02608 } 02609 template<typename A1, typename A2> 02610 BaseDatagramLayerPtr createDatagramLayer(EndPointType endPoint, Context* ctx, A1 a1, A2 a2) { 02611 return BaseDatagramLayer<EndPointType>::createDatagramLayer(&mSSTConnVars, endPoint, ctx, a1, a2); 02612 } 02613 template<typename A1, typename A2, typename A3> 02614 BaseDatagramLayerPtr createDatagramLayer(EndPointType endPoint, Context* ctx, A1 a1, A2 a2, A3 a3) { 02615 return BaseDatagramLayer<EndPointType>::createDatagramLayer(&mSSTConnVars, endPoint, ctx, a1, a2, a3); 02616 } 02617 02618 BaseDatagramLayerPtr getDatagramLayer(EndPointType endPoint) { 02619 return mSSTConnVars.getDatagramLayer(endPoint); 02620 } 02621 02622 bool listen(StreamReturnCallbackFunction cb, EndPoint <EndPointType> listeningEndPoint) { 02623 return Stream<EndPointType>::listen(&mSSTConnVars, cb, listeningEndPoint); 02624 } 02625 02626 bool unlisten( EndPoint <EndPointType> listeningEndPoint) { 02627 return Stream<EndPointType>::unlisten(&mSSTConnVars, listeningEndPoint); 02628 } 02629 02630 //Storage class for SST's global variables. 02631 ConnectionVariables<EndPointType> mSSTConnVars; 02632 }; 02633 02634 02635 02636 } // namespace SST 02637 } // namespace Sirikata 02638 02639 #endif