You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

248 lines
7.3 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import math
import openjij as oj
# ============================
# Distance matrix for a 10-city TSP example
# ============================
# D[i, j] is the distance (tour cost contribution) between city i and city j.
# The matrix is symmetric (D[i, j] == D[j, i]) with a zero diagonal; its
# largest entry is 9.
D = np.array([
    [0, 2, 8, 5, 7, 6, 3, 9, 4, 2],
    [2, 0, 5, 7, 3, 8, 4, 6, 9, 1],
    [8, 5, 0, 2, 6, 7, 3, 4, 1, 5],
    [5, 7, 2, 0, 4, 3, 8, 6, 2, 7],
    [7, 3, 6, 4, 0, 5, 9, 2, 7, 3],
    [6, 8, 7, 3, 5, 0, 4, 7, 9, 6],
    [3, 4, 3, 8, 9, 4, 0, 5, 6, 2],
    [9, 6, 4, 6, 2, 7, 5, 0, 8, 3],
    [4, 9, 1, 2, 7, 9, 6, 8, 0, 5],
    [2, 1, 5, 7, 3, 6, 2, 3, 5, 0]
])
# Number of cities (one row of D per city).
N = D.shape[0]
penalty = 20 # Constraint penalty weight; it only needs to exceed max(D) (9 here).
def idx(i, t, N):
    """Map (city ``i``, time slot ``t``) to its flat QUBO variable index.

    Variables are laid out city-major: the N time slots of city 0 come
    first (indices 0..N-1), then the N slots of city 1, and so on.
    """
    city_offset = N * i
    return city_offset + t
# ============================
# Build the QUBO (position encoding; optionally pin the start city)
# ============================
def build_qubo(D, penalty, fix_start=True, *, start_city=0, start_weight=9999.0):
    """Build the position-encoded TSP QUBO.

    Binary variable x[i, t] (flat index ``idx(i, t, N)``) is 1 when city i
    is visited at time slot t. The energy encoded in ``Q`` is

        sum_t sum_{i != j} D[i, j] * x[i, t] * x[j, (t + 1) % N]   (tour length)
      + penalty * sum_t (sum_i x[i, t] - 1)^2                       (one city per slot)
      + penalty * sum_i (sum_t x[i, t] - 1)^2                       (each city once)

    with the constant terms of the squared penalties dropped.

    Args:
        D: square distance matrix, indexed as D[i, j].
        penalty: weight of the two one-hot constraints; keep it larger than
            max(D) so violating a constraint never shortens the tour energy.
        fix_start: if True, bias slot 0 so that ``start_city`` occupies it.
        start_city: city pinned to slot 0 when ``fix_start`` is True
            (default 0, the original behavior).
        start_weight: bias magnitude used for pinning. It is added to the
            slot-0 variable of every other city and subtracted from
            ``start_city``'s (default 9999.0, the original value).

    Returns:
        (Q, N): ``Q`` maps (u, v) variable-index pairs to float coefficients;
        ``N`` is the number of cities.

    Raises:
        ValueError: if ``D`` is not a square 2-D matrix, or if ``fix_start``
            is True and ``start_city`` is not in 0..N-1.
    """
    D = np.asarray(D)
    if D.ndim != 2 or D.shape[0] != D.shape[1]:
        raise ValueError(f"D must be a square matrix, got shape {D.shape}")
    N = D.shape[0]
    Q = {}

    def add(u, v, w):
        # Accumulate: several terms below contribute to the same (u, v) key.
        Q[(u, v)] = Q.get((u, v), 0) + w

    # Tour length: city i at slot t followed by city j at slot t+1. The
    # modulo wraps the last slot back to slot 0, so the tour is closed.
    for t in range(N):
        t_next = (t + 1) % N
        for i in range(N):
            for j in range(N):
                if i != j:
                    add(idx(i, t, N), idx(j, t_next, N), float(D[i, j]))

    # Constraint A: exactly one city per time slot.
    # penalty * (sum_i x - 1)^2 expands (x^2 == x for binaries, constant
    # dropped) to -penalty on each diagonal term and +2*penalty per pair.
    for t in range(N):
        for i in range(N):
            u = idx(i, t, N)
            add(u, u, -penalty)
            for j in range(i + 1, N):
                add(u, idx(j, t, N), 2 * penalty)

    # Constraint B: each city occupies exactly one time slot (same expansion).
    for i in range(N):
        for t in range(N):
            u = idx(i, t, N)
            add(u, u, -penalty)
            for t2 in range(t + 1, N):
                add(u, idx(i, t2, N), 2 * penalty)

    # Pin the start: forbid every other city in slot 0, reward start_city there.
    if fix_start:
        if not 0 <= start_city < N:
            raise ValueError(f"start_city must be in 0..{N - 1}, got {start_city}")
        for i in range(N):
            u = idx(i, 0, N)
            add(u, u, -start_weight if i == start_city else start_weight)
    return Q, N
# ============================
# Route decoding & tour cost
# ============================
def decode_route(sample, N):
    """Read the visiting order out of a binary QUBO sample.

    For each time slot t, the city reported is the lowest-numbered i whose
    variable ``idx(i, t, N)`` equals 1 in ``sample``. Variables absent from
    ``sample`` count as 0. A slot with no such city is reported as -1, so a
    broken assignment stays visible when debugging.
    """

    def city_in_slot(t):
        # First matching city wins; -1 when the slot is empty.
        active = (i for i in range(N) if sample.get(idx(i, t, N), 0) == 1)
        return next(active, -1)

    return [city_in_slot(t) for t in range(N)]
def compute_cost(route, D):
    """Return the length of the closed tour ``route`` under distance matrix ``D``.

    The tour is treated as a cycle, so the leg from the last city back to
    the first is included. An empty route costs 0.

    Raises:
        ValueError: if any entry of ``route`` is not a valid row index of
            ``D``. In particular, the -1 "empty slot" marker produced by
            decode_route is rejected. NumPy would otherwise read D[-1, ...]
            as the last city's row and return a plausible-looking but wrong
            cost.
    """
    n_cities = len(D)
    invalid = [city for city in route if not 0 <= city < n_cities]
    if invalid:
        raise ValueError(
            f"route has city indices outside 0..{n_cities - 1}: {invalid} "
            f"(route={list(route)})"
        )
    n = len(route)
    total = 0
    for k, a in enumerate(route):
        b = route[(k + 1) % n]  # wraps to close the tour
        total += D[a, b]
    return total
# ============================
# Simple KDE implementation (Gaussian kernel)
# ============================
def kde_1d(samples, num_points=200):
    """Gaussian kernel density estimate of the 1-D values in ``samples``.

    Returns:
        (xs, ys): ``num_points`` evenly spaced grid points spanning
        [min(samples), max(samples)], and the estimated density at each.

    The bandwidth is h = 1.06 * std * n**(-1/5) (Silverman's normal-reference
    rule). When all samples are equal, std is replaced by 1.0 so the
    bandwidth stays positive. Fewer than two samples yield an all-zero
    density; an empty input yields two empty arrays.
    """
    data = np.asarray(samples, dtype=float).ravel()
    n = data.size
    if n == 0:
        # min()/max() are undefined for an empty input.
        return np.zeros(0), np.zeros(0)
    xs = np.linspace(data.min(), data.max(), num_points)
    if n < 2:
        return xs, np.zeros_like(xs)
    std = np.std(data)
    if std == 0:
        std = 1.0
    # Silverman's rule
    h = 1.06 * std * (n ** (-1 / 5))
    if h == 0:
        # Guard against underflow for extremely small spreads.
        h = 1.0
    # One row per grid point, one column per sample; sum the kernels per row.
    z = (xs[:, None] - data[None, :]) / h
    kernels = np.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    ys = kernels.sum(axis=1) / (n * h)
    return xs, ys
# ============================
# Route diversity heatmap
# ============================
def route_diversity_matrix(routes, N):
    """Share of ``routes`` that place each city at each tour position.

    Returns an (N, N) float array whose entry [city, t] is the fraction of
    ``routes`` visiting ``city`` at position t. City values outside 0..N-1
    (such as the -1 "empty slot" marker) are not counted. With no routes,
    the matrix is all zeros.
    """
    hits = Counter(
        (city, pos)
        for route in routes
        for pos, city in enumerate(route)
        if 0 <= city < N
    )
    freq = np.zeros((N, N), dtype=float)
    for (city, pos), count in hits.items():
        freq[city, pos] = count
    n_routes = len(routes)
    if n_routes:
        freq /= n_routes  # counts -> probabilities
    return freq
# ============================
# Run SA or SQA several times
# ============================
def run_algorithm(name, sampler, Q, D, N, num_runs=20, num_reads=20):
    """Solve the QUBO ``num_runs`` times with ``sampler`` and collect the tours.

    Each run calls ``sampler.sample_qubo(Q, num_reads=num_reads)`` and decodes
    the result's ``first`` sample into a route with ``decode_route``.

    Only feasible routes (permutations of 0..N-1) are costed and returned.
    Infeasible routes, such as one with an empty slot (-1) or a city visited
    twice, are printed and skipped. Their "cost" is meaningless, and costing
    a -1 entry would silently read the last row of D through NumPy negative
    indexing, which would distort the statistics.

    Returns:
        (costs, routes): parallel lists for the feasible runs only; routes
        are tuples. They may hold fewer than ``num_runs`` entries (possibly
        none).
    """
    all_costs = []
    all_routes = []
    expected_cities = list(range(N))
    n_infeasible = 0
    print(f"\n=== Running {name} ===")
    for r in range(num_runs):
        result = sampler.sample_qubo(Q, num_reads=num_reads)
        route = decode_route(result.first.sample, N)
        if sorted(route) != expected_cities:
            n_infeasible += 1
            print(f"{name} Run {r+1:02d}: Route={route}, infeasible (skipped)")
            continue
        cost = compute_cost(route, D)
        all_costs.append(cost)
        all_routes.append(tuple(route))
        print(f"{name} Run {r+1:02d}: Route={route}, Cost={cost}")
    if n_infeasible:
        print(f"{name}: {n_infeasible}/{num_runs} runs produced infeasible routes")
    return all_costs, all_routes
# ============================
# Main
# ============================
if __name__ == "__main__":
    Q, N = build_qubo(D, penalty, fix_start=True)
    # Samplers
    sqa_sampler = oj.SQASampler()
    sa_sampler = oj.SASampler()
    # Adjust the number of runs and num_reads as needed.
    NUM_RUNS = 20
    NUM_READS = 20
    sqa_costs, sqa_routes = run_algorithm("SQA", sqa_sampler, Q, D, N, NUM_RUNS, NUM_READS)
    sa_costs, sa_routes = run_algorithm("SA", sa_sampler, Q, D, N, NUM_RUNS, NUM_READS)
    # Every statistic and plot below needs at least one cost per method.
    if not sqa_costs or not sa_costs:
        raise SystemExit("No usable tour costs for SA and/or SQA; nothing to summarize.")
    # ----------------------------
    # Basic statistics
    # ----------------------------
    print("\n=== Summary ===")
    print(f"SQA: min={min(sqa_costs)}, max={max(sqa_costs)}, mean={np.mean(sqa_costs):.2f}, std={np.std(sqa_costs):.2f}")
    print(f"SA : min={min(sa_costs)}, max={max(sa_costs)}, mean={np.mean(sa_costs):.2f}, std={np.std(sa_costs):.2f}")
    # ============================
    # Figure 1: cost histogram + KDE
    # ============================
    fig1, ax1 = plt.subplots(figsize=(8, 5))
    pooled_costs = sa_costs + sqa_costs
    # Unit-width bins; the extra edge keeps the maximum cost inside the last bin.
    bins = np.arange(min(pooled_costs), max(pooled_costs) + 2)
    ax1.hist(sa_costs, bins=bins, alpha=0.4, label="SA", density=True)
    ax1.hist(sqa_costs, bins=bins, alpha=0.4, label="SQA", density=True)
    xs_sa, ys_sa = kde_1d(sa_costs)
    xs_sqa, ys_sqa = kde_1d(sqa_costs)
    ax1.plot(xs_sa, ys_sa, label="SA KDE")
    ax1.plot(xs_sqa, ys_sqa, label="SQA KDE")
    ax1.set_title("Tour Cost Distribution (SA vs SQA)")
    ax1.set_xlabel("Tour Cost")
    ax1.set_ylabel("Density")
    ax1.legend()
    fig1.tight_layout()
    fig1.savefig("tsp_cost_hist_kde_sa_vs_sqa.png")
    # ============================
    # Figure 2: violin plot + box plot
    # ============================
    fig2, axs2 = plt.subplots(1, 2, figsize=(10, 5))
    # Violin
    axs2[0].violinplot([sa_costs, sqa_costs], positions=[1, 2], showmeans=True)
    axs2[0].set_xticks([1, 2])
    axs2[0].set_xticklabels(["SA", "SQA"])
    axs2[0].set_title("Violin Plot of Tour Costs")
    # Box plot. Boxes sit at positions 1 and 2 by default. Label them via the
    # ticks because boxplot's `labels=` keyword was renamed `tick_labels` in
    # Matplotlib 3.9 and is deprecated.
    axs2[1].boxplot([sa_costs, sqa_costs], showmeans=True)
    axs2[1].set_xticks([1, 2])
    axs2[1].set_xticklabels(["SA", "SQA"])
    axs2[1].set_title("Box Plot of Tour Costs")
    fig2.tight_layout()
    fig2.savefig("tsp_cost_violin_box_sa_vs_sqa.png")
    # ============================
    # Figure 3: route diversity heatmap
    # ============================
    sqa_mat = route_diversity_matrix(sqa_routes, N)
    sa_mat = route_diversity_matrix(sa_routes, N)
    fig3, axs3 = plt.subplots(1, 2, figsize=(12, 5))
    im0 = axs3[0].imshow(sa_mat, aspect='auto', origin='lower')
    axs3[0].set_title("Route Diversity (SA)")
    axs3[0].set_xlabel("Position t")
    axs3[0].set_ylabel("City index")
    fig3.colorbar(im0, ax=axs3[0])
    im1 = axs3[1].imshow(sqa_mat, aspect='auto', origin='lower')
    axs3[1].set_title("Route Diversity (SQA)")
    axs3[1].set_xlabel("Position t")
    axs3[1].set_ylabel("City index")
    fig3.colorbar(im1, ax=axs3[1])
    fig3.tight_layout()
    fig3.savefig("tsp_route_diversity_heatmap_sa_vs_sqa.png")
    plt.show()