diff etc/tftp.lua @ 0:4b915342e2a8

LuaSocket 2.0.2 + CMake build description.
author Eric Wing <ewing . public |-at-| gmail . com>
date Tue, 26 Aug 2008 18:40:01 -0700
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/etc/tftp.lua	Tue Aug 26 18:40:01 2008 -0700
@@ -0,0 +1,155 @@
+-----------------------------------------------------------------------------
+-- TFTP support for the Lua language
+-- LuaSocket toolkit.
+-- Author: Diego Nehab
+-- RCS ID: $Id: tftp.lua,v 1.16 2005/11/22 08:33:29 diego Exp $
+-----------------------------------------------------------------------------
+
+-----------------------------------------------------------------------------
+-- Load required files
+-----------------------------------------------------------------------------
+local base = _G
+local table = require("table")
+local math = require("math")
+local string = require("string")
+local socket = require("socket")
+local ltn12 = require("ltn12")
+local url = require("socket.url")
+module("socket.tftp")
+
+-----------------------------------------------------------------------------
+-- Program constants
+-----------------------------------------------------------------------------
+local char = string.char
+local byte = string.byte
+
+PORT = 69
+local OP_RRQ = 1
+local OP_WRQ = 2
+local OP_DATA = 3
+local OP_ACK = 4
+local OP_ERROR = 5
+local OP_INV = {"RRQ", "WRQ", "DATA", "ACK", "ERROR"}
+
+-----------------------------------------------------------------------------
+-- Packet creation functions
+-----------------------------------------------------------------------------
+local function RRQ(source, mode)
+	return char(0, OP_RRQ) .. source .. char(0) .. mode .. char(0)
+end
+
+local function WRQ(source, mode)
+	return char(0, OP_RRQ) .. source .. char(0) .. mode .. char(0)
+end
+
+local function ACK(block)
+	local low, high
+	low = math.mod(block, 256)
+	high = (block - low)/256
+	return char(0, OP_ACK, high, low)
+end
+
+local function get_OP(dgram)
+    local op = byte(dgram, 1)*256 + byte(dgram, 2)
+    return op
+end
+
+-----------------------------------------------------------------------------
+-- Packet analysis functions
+-----------------------------------------------------------------------------
+local function split_DATA(dgram)
+	local block = byte(dgram, 3)*256 + byte(dgram, 4)
+	local data = string.sub(dgram, 5)
+	return block, data
+end
+
+local function get_ERROR(dgram)
+	local code = byte(dgram, 3)*256 + byte(dgram, 4)
+	local msg
+	_,_, msg = string.find(dgram, "(.*)\000", 5)
+	return string.format("error code %d: %s", code, msg)
+end
+
+-----------------------------------------------------------------------------
+-- The real work
+-----------------------------------------------------------------------------
+local function tget(gett)
+    local retries, dgram, sent, datahost, dataport, code
+    local last = 0
+    socket.try(gett.host, "missing host")
+	local con = socket.try(socket.udp())
+    local try = socket.newtry(function() con:close() end)
+    -- convert from name to ip if needed
+	gett.host = try(socket.dns.toip(gett.host))
+	con:settimeout(1)
+    -- first packet gives data host/port to be used for data transfers
+    local path = string.gsub(gett.path or "", "^/", "")
+    path = url.unescape(path)
+    retries = 0
+	repeat
+		sent = try(con:sendto(RRQ(path, "octet"), gett.host, gett.port))
+		dgram, datahost, dataport = con:receivefrom()
+        retries = retries + 1
+	until dgram or datahost ~= "timeout" or retries > 5
+	try(dgram, datahost)
+    -- associate socket with data host/port
+	try(con:setpeername(datahost, dataport))
+    -- default sink
+    local sink = gett.sink or ltn12.sink.null()
+    -- process all data packets
+	while 1 do
+        -- decode packet
+		code = get_OP(dgram)
+		try(code ~= OP_ERROR, get_ERROR(dgram))
+        try(code == OP_DATA, "unhandled opcode " .. code)
+        -- get data packet parts
+		local block, data = split_DATA(dgram)
+        -- if not repeated, write
+        if block == last+1 then
+		    try(sink(data))
+            last = block
+        end
+        -- last packet brings less than 512 bytes of data
+		if string.len(data) < 512 then
+            try(con:send(ACK(block)))
+            try(con:close())
+            try(sink(nil))
+            return 1
+        end
+        -- get the next packet
+        retries = 0
+		repeat
+			sent = try(con:send(ACK(last)))
+			dgram, err = con:receive()
+            retries = retries + 1
+		until dgram or err ~= "timeout" or retries > 5
+		try(dgram, err)
+	end
+end
+
+local default = {
+    port = PORT,
+    path ="/",
+    scheme = "tftp"
+}
+
+local function parse(u)
+    local t = socket.try(url.parse(u, default))
+    socket.try(t.scheme == "tftp", "invalid scheme '" .. t.scheme .. "'")
+    socket.try(t.host, "invalid host")
+    return t
+end
+
+local function sget(u)
+    local gett = parse(u)
+    local t = {}
+    gett.sink = ltn12.sink.table(t)
+    tget(gett)
+    return table.concat(t)
+end
+
+get = socket.protect(function(gett)
+    if base.type(gett) == "string" then return sget(gett)
+    else return tget(gett) end
+end)
+