You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

248 lines
7.3 KiB
Python

5 months ago
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import math
import openjij as oj
# ============================
# Example distance matrix for a 10-city TSP (symmetric, zero diagonal)
# ============================
D = np.array(
    [
        [0, 2, 8, 5, 7, 6, 3, 9, 4, 2],
        [2, 0, 5, 7, 3, 8, 4, 6, 9, 1],
        [8, 5, 0, 2, 6, 7, 3, 4, 1, 5],
        [5, 7, 2, 0, 4, 3, 8, 6, 2, 7],
        [7, 3, 6, 4, 0, 5, 9, 2, 7, 3],
        [6, 8, 7, 3, 5, 0, 4, 7, 9, 6],
        [3, 4, 3, 8, 9, 4, 0, 5, 6, 2],
        [9, 6, 4, 6, 2, 7, 5, 0, 8, 3],
        [4, 9, 1, 2, 7, 9, 6, 8, 0, 5],
        [2, 1, 5, 7, 3, 6, 2, 3, 5, 0],
    ]
)
# Number of cities (rows of D).
N = len(D)
# Weight of the one-hot constraint penalties in the QUBO.
# Author's rule of thumb: "penalty > max(D) is enough" (max(D) == 9 here).
# NOTE(review): this is a heuristic, not a guarantee that the QUBO minimum is
# a feasible tour.
penalty = 20
def idx(i, t, N):
    """Return the flat QUBO variable index for "city ``i`` at tour position ``t``".

    The layout is city-major: indices ``i*N`` through ``i*N + N - 1`` hold the
    ``N`` tour positions of city ``i``.
    """
    return t + N * i
# ============================
# Build the QUBO (position encoding + optional fixed start city)
# ============================
def _accumulate(Q, u, v, w):
    """Add weight ``w`` to the QUBO coefficient ``Q[(u, v)]``, creating it at 0."""
    Q[(u, v)] = Q.get((u, v), 0) + w


def _add_one_hot(Q, variables, weight):
    """Add ``weight * (sum(x_v for v in variables) - 1) ** 2`` to ``Q``, minus its constant.

    For binary variables ``x**2 == x``, so the square expands to
    ``-sum_v x_v + 2 * sum_{a<b} x_a x_b + 1``. The constant ``+1`` does not change
    which assignment minimizes the energy and is dropped.
    """
    for k, u in enumerate(variables):
        _accumulate(Q, u, u, -weight)
        for v in variables[k + 1:]:
            _accumulate(Q, u, v, 2 * weight)


def build_qubo(D, penalty, fix_start=True, *, start_city=0, fix_weight=9999.0):
    """Build a position-encoded TSP QUBO.

    Binary variable ``idx(i, t, N)`` is 1 when city ``i`` is visited at tour
    position ``t``. The energy is the sum of:

    * the tour length: ``D[i, j]`` for city ``i`` at slot ``t`` followed by city
      ``j`` at slot ``(t + 1) % N`` (the tour wraps from the last slot to slot 0);
    * ``penalty * (sum_i x[i, t] - 1)**2`` per slot ``t`` (one city per slot);
    * ``penalty * (sum_t x[i, t] - 1)**2`` per city ``i`` (each city visited once);
    * if ``fix_start``: ``+fix_weight`` on every city other than ``start_city`` at
      slot 0 and ``-fix_weight`` on ``start_city`` at slot 0. Because the tour
      length term is cyclic, every rotation of a tour has the same length; pinning
      slot 0 keeps just one rotation.

    Args:
        D: square ``(N, N)`` distance matrix (anything ``np.asarray`` accepts).
        penalty: weight of the one-hot constraints; it should dominate the
            distance terms so that violating a constraint never pays off.
        fix_start: pin ``start_city`` to slot 0 when True.
        start_city: city pinned to slot 0 (keyword-only, default 0).
        fix_weight: magnitude of the pinning terms (keyword-only, default 9999.0).
            NOTE(review): this is ~500x the default penalty; very large
            coefficients can dominate the energy scale an annealer works with,
            so it may be worth tuning.

    Returns:
        ``(Q, N)``: ``Q`` maps ``(u, v)`` index pairs to coefficients (pairs are not
        normalized to ``u <= v``; the wrap-around distance terms produce some
        ``u > v``), and ``N`` is the number of cities.

    Raises:
        ValueError: if ``D`` is not square, or ``start_city`` is out of range
            while ``fix_start`` is True.
    """
    D = np.asarray(D)
    if D.ndim != 2 or D.shape[0] != D.shape[1]:
        raise ValueError(f"D must be a square matrix, got shape {D.shape}")
    N = D.shape[0]
    if fix_start and not 0 <= start_city < N:
        raise ValueError(f"start_city must be in [0, {N}), got {start_city}")

    Q = {}

    # Distance term: consecutive slots t -> t+1 (slot N-1 wraps to slot 0).
    for t in range(N):
        t_next = (t + 1) % N
        for i in range(N):
            for j in range(N):
                if i != j:
                    _accumulate(Q, idx(i, t, N), idx(j, t_next, N), D[i, j])

    # Constraint A: every time slot holds exactly one city.
    for t in range(N):
        _add_one_hot(Q, [idx(i, t, N) for i in range(N)], penalty)

    # Constraint B: every city is visited exactly once.
    for i in range(N):
        _add_one_hot(Q, [idx(i, t, N) for t in range(N)], penalty)

    # Fixed start: forbid every other city at slot 0, reward start_city there.
    if fix_start:
        for i in range(N):
            if i != start_city:
                u = idx(i, 0, N)
                _accumulate(Q, u, u, fix_weight)
        u0 = idx(start_city, 0, N)
        _accumulate(Q, u0, u0, -fix_weight)

    return Q, N
# ============================
# Route decoding & tour cost
# ============================
def decode_route(sample, N):
    """Translate a binary assignment into the city visited at each tour position.

    ``sample`` is read as a mapping from variable index to value; variables it does
    not contain count as 0. For every position ``t`` the reported city is the
    lowest-index ``i`` whose variable ``idx(i, t, N)`` equals 1. A position with no
    active city is reported as ``-1`` so broken solutions are easy to spot while
    debugging.

    Returns:
        A list of ``N`` city indices (``-1`` for empty positions).
    """
    def city_at(slot):
        # First city switched on at this slot, or the -1 placeholder.
        return next(
            (city for city in range(N) if sample.get(idx(city, slot, N), 0) == 1),
            -1,
        )

    return [city_at(slot) for slot in range(N)]
def compute_cost(route, D):
    """Return the length of the closed tour ``route`` under distance matrix ``D``.

    The tour is treated as a cycle: the edge from the last city back to the first
    one is included. An empty route costs 0.

    Raises:
        ValueError: if any entry of ``route`` is not a valid row index of ``D``.
            This rejects, for example, the ``-1`` placeholder ``decode_route`` uses
            for an empty slot. Without the check, NumPy's negative indexing would
            silently read the last row and return a plausible-looking but
            meaningless cost.
    """
    n_cities = len(D)
    invalid = [city for city in route if not 0 <= city < n_cities]
    if invalid:
        raise ValueError(
            f"route {list(route)} contains city indices outside [0, {n_cities}): {invalid}"
        )
    N = len(route)
    total = 0
    for k in range(N):
        # (k + 1) % N closes the cycle back to the first city.
        total += D[route[k], route[(k + 1) % N]]
    return total
# ============================
# Simple KDE implementation (Gaussian kernel)
# ============================
def kde_1d(samples, num_points=200):
    """Estimate a 1-D density of ``samples`` with a Gaussian kernel.

    The density is evaluated on ``num_points`` evenly spaced points spanning
    ``[min(samples), max(samples)]``. The bandwidth follows Silverman's rule
    ``h = 1.06 * std * n**(-1/5)``, using the population standard deviation
    (``np.std``). Constant data falls back to ``std = 1``.

    Returns:
        ``(xs, ys)`` float arrays. Both are empty for empty input. ``ys`` is all
        zeros when fewer than two samples are given.
    """
    data = np.asarray(samples, dtype=float).ravel()
    n = data.size
    if n == 0:
        # min()/max() are undefined on empty data; return an empty curve instead.
        return np.zeros(0), np.zeros(0)
    xs = np.linspace(data.min(), data.max(), num_points)
    if n < 2:
        return xs, np.zeros_like(xs)

    std = np.std(data)
    if std == 0:
        std = 1.0
    # Silverman's rule of thumb
    h = 1.06 * std * (n ** (-1 / 5))
    if h == 0:
        # Guards against underflow for extremely small spreads.
        h = 1.0

    # Broadcast to a (num_points, n) grid of standardized distances; this
    # replaces the original per-point Python loop with the same formula.
    z = (xs[:, None] - data[None, :]) / h
    ys = np.exp(-0.5 * z * z).sum(axis=1) / (n * h * math.sqrt(2.0 * math.pi))
    return xs, ys
# ============================
# Route diversity heatmap
# ============================
def route_diversity_matrix(routes, N):
    """Return an ``(N, N)`` matrix of how often each city sits at each position.

    Entry ``[city, position]`` is the fraction of ``routes`` that place ``city`` at
    ``position``. Entries outside ``[0, N)`` (e.g. a ``-1`` placeholder) are ignored.
    With no routes, the all-zero matrix is returned.
    """
    counts = np.zeros((N, N), dtype=float)
    for tour in routes:
        for position, city in enumerate(tour):
            if city < 0 or city >= N:
                continue
            counts[city, position] += 1
    total = len(routes)
    # Normalize counts to per-route probabilities.
    return counts / total if total else counts
# ============================
# Run SA or SQA repeatedly
# ============================
def run_algorithm(name, sampler, Q, D, N, num_runs=20, num_reads=20):
    """Sample the QUBO ``num_runs`` times and collect one tour per run.

    Each run calls ``sampler.sample_qubo(Q, num_reads=num_reads)`` and decodes
    ``result.first.sample`` into a route. NOTE(review): ``first`` is presumably
    the lowest-energy sample; confirm for the OpenJij version in use.

    A decoded route that is not a permutation of ``range(N)`` violates the
    one-hot constraints. Its "cost" would not be a tour length (an empty slot
    decodes to ``-1``; a repeated city is counted twice). Such runs are printed,
    marked INFEASIBLE, and left out of the returned lists, so they do not skew
    the SA/SQA statistics.

    Returns:
        ``(all_costs, all_routes)`` for the feasible runs, in run order; routes are
        tuples. Either list may be shorter than ``num_runs``, or empty.
    """
    all_costs = []
    all_routes = []
    valid_tour = list(range(N))
    print(f"\n=== Running {name} ===")
    for r in range(num_runs):
        result = sampler.sample_qubo(Q, num_reads=num_reads)
        route = decode_route(result.first.sample, N)
        if sorted(route) != valid_tour:
            print(f"{name} Run {r+1:02d}: Route={route}, INFEASIBLE (skipped)")
            continue
        cost = compute_cost(route, D)
        all_costs.append(cost)
        all_routes.append(tuple(route))
        print(f"{name} Run {r+1:02d}: Route={route}, Cost={cost}")
    print(f"{name}: {len(all_costs)}/{num_runs} runs produced a feasible tour")
    return all_costs, all_routes
# ============================
# Main
# ============================
if __name__ == "__main__":
    Q, N = build_qubo(D, penalty, fix_start=True)

    # Samplers
    sqa_sampler = oj.SQASampler()
    sa_sampler = oj.SASampler()

    # Tune the number of independent runs and reads per run as needed.
    NUM_RUNS = 20
    NUM_READS = 20

    sqa_costs, sqa_routes = run_algorithm("SQA", sqa_sampler, Q, D, N, NUM_RUNS, NUM_READS)
    sa_costs, sa_routes = run_algorithm("SA", sa_sampler, Q, D, N, NUM_RUNS, NUM_READS)

    # ----------------------------
    # Basic statistics
    # ----------------------------
    print("\n=== Summary ===")
    # "SA " (trailing space) keeps the two summary lines aligned.
    for label, costs in (("SQA", sqa_costs), ("SA ", sa_costs)):
        if costs:
            print(
                f"{label}: min={min(costs)}, max={max(costs)}, "
                f"mean={np.mean(costs):.2f}, std={np.std(costs):.2f}"
            )
        else:
            # min()/max() would raise on an empty list.
            print(f"{label}: no costs recorded")

    if sa_costs and sqa_costs:
        # ============================
        # Figure 1: cost histogram + KDE
        # ============================
        fig1, ax1 = plt.subplots(figsize=(8, 5))
        all_costs = list(sa_costs) + list(sqa_costs)
        # One unit-wide bin per integer cost; np.arange also copes with float costs.
        bins = np.arange(min(all_costs), max(all_costs) + 2)
        ax1.hist(sa_costs, bins=bins, alpha=0.4, label="SA", density=True)
        ax1.hist(sqa_costs, bins=bins, alpha=0.4, label="SQA", density=True)
        xs_sa, ys_sa = kde_1d(sa_costs)
        xs_sqa, ys_sqa = kde_1d(sqa_costs)
        ax1.plot(xs_sa, ys_sa, label="SA KDE")
        ax1.plot(xs_sqa, ys_sqa, label="SQA KDE")
        ax1.set_title("Tour Cost Distribution (SA vs SQA)")
        ax1.set_xlabel("Tour Cost")
        ax1.set_ylabel("Density")
        ax1.legend()
        fig1.tight_layout()
        fig1.savefig("tsp_cost_hist_kde_sa_vs_sqa.png")

        # ============================
        # Figure 2: violin plot + box plot
        # ============================
        fig2, axs2 = plt.subplots(1, 2, figsize=(10, 5))
        # Violin
        axs2[0].violinplot([sa_costs, sqa_costs], positions=[1, 2], showmeans=True)
        axs2[0].set_xticks([1, 2])
        axs2[0].set_xticklabels(["SA", "SQA"])
        axs2[0].set_title("Violin Plot of Tour Costs")
        # Box plot. Tick labels are set explicitly because boxplot's `labels`
        # argument was renamed to `tick_labels` in Matplotlib 3.9; boxes sit at
        # positions 1..n by default.
        axs2[1].boxplot([sa_costs, sqa_costs], showmeans=True)
        axs2[1].set_xticks([1, 2])
        axs2[1].set_xticklabels(["SA", "SQA"])
        axs2[1].set_title("Box Plot of Tour Costs")
        fig2.tight_layout()
        fig2.savefig("tsp_cost_violin_box_sa_vs_sqa.png")
    else:
        print("Skipping cost plots: SA or SQA has no recorded costs.")

    # ============================
    # Figure 3: route diversity heatmap
    # ============================
    sqa_mat = route_diversity_matrix(sqa_routes, N)
    sa_mat = route_diversity_matrix(sa_routes, N)
    fig3, axs3 = plt.subplots(1, 2, figsize=(12, 5))
    # Shared 0..1 color scale so equal colors mean equal frequencies on both panels.
    im0 = axs3[0].imshow(sa_mat, aspect='auto', origin='lower', vmin=0.0, vmax=1.0)
    axs3[0].set_title("Route Diversity (SA)")
    axs3[0].set_xlabel("Position t")
    axs3[0].set_ylabel("City index")
    fig3.colorbar(im0, ax=axs3[0])
    im1 = axs3[1].imshow(sqa_mat, aspect='auto', origin='lower', vmin=0.0, vmax=1.0)
    axs3[1].set_title("Route Diversity (SQA)")
    axs3[1].set_xlabel("Position t")
    axs3[1].set_ylabel("City index")
    fig3.colorbar(im1, ax=axs3[1])
    fig3.tight_layout()
    fig3.savefig("tsp_route_diversity_heatmap_sa_vs_sqa.png")
    plt.show()