summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMiklós Fazekas <mfazekas@szemafor.com>2010-02-28 10:23:55 +0100
committerMiklós Fazekas <mfazekas@szemafor.com>2010-02-28 10:32:45 +0100
commitaba7f6d18dbf58be1896e01b3b8f0d648459e359 (patch)
tree6b9ebaee8ba0f4f2d8a00bb9dcbd35e7d6c4e3f0
parent40bc65a0f4c58bb113688ff7d238b7bcd6b586ca (diff)
downloadnet-ssh-aba7f6d18dbf58be1896e01b3b8f0d648459e359.tar.gz
Fixed ssh port forward closing bug, by catching closed, conn_reset exceptions, and acting as the socket was closed: todo do it the proper way
-rw-r--r--endtoendtests/test_forward.rb39
-rw-r--r--lib/net/ssh/service/forward.rb35
2 files changed, 61 insertions, 13 deletions
diff --git a/endtoendtests/test_forward.rb b/endtoendtests/test_forward.rb
index ba8e519..377720b 100644
--- a/endtoendtests/test_forward.rb
+++ b/endtoendtests/test_forward.rb
@@ -23,30 +23,39 @@ class TestForward < Test::Unit::TestCase
8080
end
- def start_server_sending_lot_of_data
+ def start_server_sending_lot_of_data(exceptions=nil)
server = TCPServer.open(0)
Thread.start do
loop do
Thread.start(server.accept) do |client|
- 10000.times do |i|
- client.puts "item#{i}"
+ begin
+ 10000.times do |i|
+ client.puts "item#{i}"
+ end
+ client.close
+ rescue
+ exceptions << $!
+ raise
end
- client.close
end
end
end
return server
end
- def start_server_closing_soon
+ def start_server_closing_soon(exceptions=nil)
server = TCPServer.open(0)
Thread.start do
loop do
Thread.start(server.accept) do |client|
- client.recv(1024)
- client.setsockopt(Socket::SOL_SOCKET, Socket::SO_LINGER, [1, 0].pack("ii"))
- client.close
- #client.close
+ begin
+ client.recv(1024)
+ client.setsockopt(Socket::SOL_SOCKET, Socket::SO_LINGER, [1, 0].pack("ii"))
+ client.close
+ rescue
+ exceptions << $!
+ raise
+ end
end
end
end
@@ -54,8 +63,9 @@ class TestForward < Test::Unit::TestCase
end
def test_loop_should_not_abort_when_local_side_of_forward_is_closed
- session = Net::SSH.start(*ssh_start_params)
- server = start_server_sending_lot_of_data
+ session = Net::SSH.start(*ssh_start_params)
+ server_exc = Queue.new
+ server = start_server_sending_lot_of_data(server_exc)
remote_port = server.addr[1]
local_port = find_free_port
session.forward.local(local_port, localhost, remote_port)
@@ -71,11 +81,13 @@ class TestForward < Test::Unit::TestCase
end
end
session.loop(0.1) { client_done.empty? }
+ assert_equal "Broken pipe", "#{server_exc.pop}"
end
def test_loop_should_not_abort_when_local_side_of_forward_is_reset
- session = Net::SSH.start(*ssh_start_params)
- server = start_server_sending_lot_of_data
+ session = Net::SSH.start(*ssh_start_params)
+ server_exc = Queue.new
+ server = start_server_sending_lot_of_data(server_exc)
remote_port = server.addr[1]
local_port = find_free_port+1
session.forward.local(local_port, localhost, remote_port)
@@ -92,6 +104,7 @@ class TestForward < Test::Unit::TestCase
end
end
session.loop(0.1) { client_done.empty? }
+ assert_equal "Broken pipe", "#{server_exc.pop}"
end
def test_loop_should_not_abort_when_server_side_of_forward_is_closed
diff --git a/lib/net/ssh/service/forward.rb b/lib/net/ssh/service/forward.rb
index 6df93ea..4338f91 100644
--- a/lib/net/ssh/service/forward.rb
+++ b/lib/net/ssh/service/forward.rb
@@ -193,12 +193,47 @@ module Net; module SSH; module Service
end
private
+
+ module ForwardedBufferedIo
+ def fill(n=8192)
+ begin
+ super(n)
+ rescue Errno::ECONNRESET => e
+ debug { "connection was reset => shallowing exception:#{e}" }
+ return 0
+ rescue IOError => e
+ if e.message =~ /closed/ then
+ debug { "connection was reset => shallowing exception:#{e}" }
+ return 0
+ else
+ raise
+ end
+ end
+ end
+
+ def send_pending
+ begin
+ super
+ rescue Errno::ECONNRESET => e
+ debug { "connection was reset => shallowing exception:#{e}" }
+ return 0
+ rescue IOError => e
+ if e.message =~ /closed/ then
+ debug { "connection was reset => shallowing exception:#{e}" }
+ return 0
+ else
+ raise
+ end
+ end
+ end
+ end
# Perform setup operations that are common to all forwarded channels.
# +client+ is a socket, +channel+ is the channel that was just created,
# and +type+ is an arbitrary string describing the type of the channel.
def prepare_client(client, channel, type)
client.extend(Net::SSH::BufferedIo)
+ client.extend(ForwardedBufferedIo)
client.logger = logger
session.listen_to(client)