本文简单实现了求解网络最大流的最短增广路径算法(即 Edmonds-Karp 算法)。
首先,我们简单实现一个队列(queue)数据结构,供后面的广度优先搜索使用:
--- Minimal FIFO queue used by the breadth-first search in max_flow_net.
-- Items live in `data[head .. tail]`. Popping advances `head` instead of
-- calling `table.remove(data, 1)`, which shifts every remaining element and
-- makes each pop O(n); here both push and pop are O(1).
local queue = {}
queue.__index = queue

--- Append `val` at the back of the queue.
-- A nil `val` is ignored: a stored nil could not be told apart from
-- "queue empty" by `pop`.
function queue:push(val)
  if val == nil then
    return
  end
  local tail = self.tail + 1
  self.tail = tail
  self.data[tail] = val
end

--- Remove and return the front item; returns nothing if the queue is empty.
function queue:pop()
  local head = self.head
  if head > self.tail then
    return
  end
  local val = self.data[head]
  self.data[head] = nil -- drop the reference so the value can be collected
  if head == self.tail then
    -- Queue drained: restart indices from 1 instead of letting them grow.
    self.head, self.tail = 1, 0
  else
    self.head = head + 1
  end
  return val
end

--- Remove all items.
function queue:clear()
  self.data = {}
  self.head, self.tail = 1, 0
end

--- True when the queue holds no items.
function queue:empty()
  return self.head > self.tail
end

--- Create a new, empty queue.
function queue.create()
  return setmetatable({ data = {}, head = 1, tail = 0 }, queue)
end
根据给定的网络结构(以邻接表表示,define_net[u][v] 为有向边 u→v 的容量),生成对应的实流网络和残余网络:
--- Build the flow network and the residual network for `define_net`.
-- `define_net[u][v]` is read as the capacity of the directed edge u -> v;
-- self-loops and edges with capacity <= 0 are ignored.
-- @param define_net adjacency table `{ [u] = { [v] = capacity, ... }, ... }`
-- @return flow_net `flow_net[u][v] = 0` for every kept edge u -> v, plus a
--   (possibly empty) table for every key `u` of define_net
-- @return residual_net `residual_net[u][v] = capacity` for every kept edge,
--   plus a back edge `residual_net[v][u]` that starts at 0 unless v -> u is
--   itself a kept edge (then it keeps that edge's own capacity)
-- Returns nothing when `define_net` is nil or false.
function prepare_net(define_net)
  if not define_net then
    return
  end
  local flow_net = {}
  local residual_net = {}
  for node, next_nodes in pairs(define_net) do
    flow_net[node] = {}
    residual_net[node] = residual_net[node] or {}
    for next_node, capacity in pairs(next_nodes) do
      -- skip ill-formed settings: self-loops and non-positive capacities
      if node ~= next_node and capacity > 0 then
        flow_net[node][next_node] = 0
        residual_net[node][next_node] = capacity
        residual_net[next_node] = residual_net[next_node] or {}
        -- Only create the back edge if it does not exist yet: when
        -- next_node -> node is also a real edge, its capacity must not be
        -- reset to 0 (this previously depended on `pairs` order).
        residual_net[next_node][node] = residual_net[next_node][node] or 0
      end
    end
  end
  return flow_net, residual_net
end
接着是算法的具体实现。基本思路是:用广度优先搜索(BFS)在残余网络中反复查找从源点到汇点的最短增广路径,按路径上的最小残余容量(瓶颈)调整残余网络和实流网络的流量,直到再也找不到增广路径为止:
--- Breadth-first search over edges with positive residual capacity.
-- Returns the node list `{ src_node, ..., dst_node }` of a shortest
-- (fewest-edges) augmenting path, or nil when dst_node is unreachable.
local function find_augmenting_path(residual_net, src_node, dst_node)
  local pending = queue.create()
  pending:push(src_node)
  local visited = { [src_node] = true }
  local parent = {}
  while not pending:empty() do
    local cur_node = pending:pop()
    local edges = residual_net[cur_node]
    if edges then
      for next_node, residual in pairs(edges) do
        if residual > 0 and not visited[next_node] then
          visited[next_node] = true
          parent[next_node] = cur_node
          if next_node == dst_node then
            -- Walk parent links back to the source (parent[src_node] is nil).
            local path = { dst_node }
            local node = cur_node
            while node ~= nil do
              table.insert(path, 1, node)
              node = parent[node]
            end
            return path
          end
          pending:push(next_node)
        end
      end
    end
  end
  return nil
end

--- Smallest residual capacity along `path`, i.e. how much it can carry.
local function path_bottleneck(residual_net, path)
  local min_flow = math.huge
  for i = 1, #path - 1 do
    local residual = residual_net[path[i]][path[i + 1]]
    if residual < min_flow then
      min_flow = residual
    end
  end
  return min_flow
end

--- Push `amount` units of flow along `path`, updating both networks.
local function augment_path(flow_net, residual_net, path, amount)
  for i = 1, #path - 1 do
    local u, v = path[i], path[i + 1]
    residual_net[u][v] = residual_net[u][v] - amount
    residual_net[v][u] = residual_net[v][u] + amount
    -- Traversing u -> v first cancels flow on the opposite edge v -> u (a
    -- residual "back edge"); only the remainder goes over the edge u -> v.
    -- When u -> v is not an original edge, the residual capacity equals the
    -- flow on v -> u, so the remainder is 0 and flow_net[u][v] is not touched.
    local back_flow = (flow_net[v] and flow_net[v][u]) or 0
    local cancelled = math.min(amount, back_flow)
    if cancelled > 0 then
      flow_net[v][u] = back_flow - cancelled
    end
    local forward = amount - cancelled
    if forward > 0 then
      flow_net[u][v] = flow_net[u][v] + forward
    end
  end
end

--- Compute a maximum flow from src_node to dst_node (Edmonds-Karp: repeatedly
-- augment along a shortest path in the residual network).
-- @param define_net adjacency table, `define_net[u][v]` = capacity of u -> v
-- @param src_node source node
-- @param dst_node sink node
-- @return flow_net `flow_net[u][v]` = flow assigned to edge u -> v
-- @return total flow value (sum of all augmentations)
-- Returns nothing for a missing network or when src_node == dst_node. If a
-- path's bottleneck is infinite, an error is printed and nothing is returned.
function max_flow_net(define_net, src_node, dst_node)
  -- ill case handling
  if not define_net or src_node == dst_node then
    return
  end
  local flow_net, residual_net = prepare_net(define_net)
  if not flow_net or not residual_net then
    return
  end
  local total_flow = 0
  while true do
    local path = find_augmenting_path(residual_net, src_node, dst_node)
    if not path then
      break
    end
    local min_flow = path_bottleneck(residual_net, path)
    -- only reachable when every edge on the path has infinite capacity
    if min_flow >= math.huge then
      print("error to get min flow of path")
      return
    end
    augment_path(flow_net, residual_net, path, min_flow)
    total_flow = total_flow + min_flow
  end
  return flow_net, total_flow
end
给出如下网络:
-- Example network as an adjacency table: define_net[u][v] is the capacity
-- of the directed edge u -> v. Node 6 has no outgoing edges.
local define_net = {
  [1] = { [2] = 12, [3] = 10 },
  [2] = { [4] = 8 },
  [3] = { [2] = 2, [5] = 13 },
  [4] = { [3] = 5, [6] = 18 },
  [5] = { [4] = 6, [6] = 4 },
  [6] = {},
}
对应的图形表示如下:
有兴趣的朋友可以看看以下代码的结果:
-- Run the example from node 1 (source) to node 6 (sink) and print the result:
-- the max-flow value (net flow leaving the source) and the flow on each edge.
local result_flow = max_flow_net(define_net, 1, 6)
if result_flow then
  local src, total = 1, 0
  for u, edges in pairs(result_flow) do
    for v, f in pairs(edges) do
      if u == src then total = total + f end
      if v == src then total = total - f end
    end
  end
  print("max flow: " .. total)
  -- Sort for stable output; assumes node ids are comparable (numbers here).
  local nodes = {}
  for u in pairs(result_flow) do nodes[#nodes + 1] = u end
  table.sort(nodes)
  for _, u in ipairs(nodes) do
    local targets = {}
    for v in pairs(result_flow[u]) do targets[#targets + 1] = v end
    table.sort(targets)
    for _, v in ipairs(targets) do
      print(string.format("%s -> %s: %s", u, v, result_flow[u][v]))
    end
  end
end