"""Compare simulated annealing (SA) and simulated quantum annealing (SQA) on a
10-city TSP encoded as a QUBO, using OpenJij samplers.

The QUBO uses position encoding: variable ``idx(i, t, N)`` is 1 when city ``i``
is visited at tour position ``t``. Both samplers are run repeatedly, and three
comparison figures are saved:

1. a cost histogram with KDE overlays;
2. violin and box plots;
3. a city-by-position frequency heatmap.

matplotlib and openjij are imported inside the functions that use them. That
way, the QUBO / decoding / statistics helpers can be imported without either
package installed.
"""

import math
from collections import Counter

import numpy as np

# ============================
# Example distance matrix: 10-city TSP (symmetric, zero diagonal)
# ============================
D = np.array([
    [0, 2, 8, 5, 7, 6, 3, 9, 4, 2],
    [2, 0, 5, 7, 3, 8, 4, 6, 9, 1],
    [8, 5, 0, 2, 6, 7, 3, 4, 1, 5],
    [5, 7, 2, 0, 4, 3, 8, 6, 2, 7],
    [7, 3, 6, 4, 0, 5, 9, 2, 7, 3],
    [6, 8, 7, 3, 5, 0, 4, 7, 9, 6],
    [3, 4, 3, 8, 9, 4, 0, 5, 6, 2],
    [9, 6, 4, 6, 2, 7, 5, 0, 8, 3],
    [4, 9, 1, 2, 7, 9, 6, 8, 0, 5],
    [2, 1, 5, 7, 3, 6, 2, 3, 5, 0],
])
N = D.shape[0]
# Constraint penalty weight. Rule of thumb used here: penalty > max(D) (max(D) is 9).
penalty = 20


def idx(i, t, N):
    """Return the flat QUBO variable index for "city i at tour position t"."""
    return i * N + t


# ============================
# QUBO construction: position encoding + optional fixed start city 0
# ============================
def build_qubo(D, penalty, fix_start=True, *, start_weight=9999.0):
    """Build the position-encoded TSP QUBO.

    For an assignment where every slot holds exactly one city and every city
    occupies exactly one slot, the energy equals the cyclic tour length minus
    the constant ``2 * penalty * N``. When ``fix_start`` is true and city 0 is
    at slot 0, ``start_weight`` is also subtracted.

    Args:
        D: square distance matrix indexed as ``D[i, j]``.
        penalty: weight of the one-city-per-slot and one-slot-per-city
            constraints.
        fix_start: if true, pin city 0 to position 0. This is done by rewarding
            that variable and penalising every other city at position 0.
        start_weight: magnitude of that reward/penalty.
            NOTE(review): the default 9999 dwarfs the distances (<= 9) and
            ``penalty`` (20). If the sampler derives its temperature schedule
            from coefficient magnitudes, this may hurt sampling quality;
            consider a value closer to a few times ``penalty``.

    Returns:
        ``(Q, N)``: the QUBO as a ``{(u, v): weight}`` dict and the number of
        cities.
    """
    N = D.shape[0]
    Q = {}

    def add(u, v, weight):
        # Accumulate, since several terms may touch the same (u, v) key.
        Q[(u, v)] = Q.get((u, v), 0) + weight

    # Distance term: city i at slot t followed by city j at slot t+1. The
    # modulo closes the cycle from the last slot back to slot 0.
    for t in range(N):
        t_next = (t + 1) % N
        for i in range(N):
            for j in range(N):
                if i != j:
                    add(idx(i, t, N), idx(j, t_next, N), D[i, j])

    # Constraint A: exactly one city per time slot.
    # penalty * (sum_i x_it - 1)^2 expands to -penalty on each diagonal entry
    # and +2 * penalty on each pair. This uses x^2 = x and drops the constant.
    for t in range(N):
        for i in range(N):
            u = idx(i, t, N)
            add(u, u, -penalty)
            for j in range(i + 1, N):
                add(u, idx(j, t, N), 2 * penalty)

    # Constraint B: each city visited exactly once (same expansion over t).
    for i in range(N):
        for t in range(N):
            u = idx(i, t, N)
            add(u, u, -penalty)
            for t2 in range(t + 1, N):
                add(u, idx(i, t2, N), 2 * penalty)

    if fix_start:
        # Forbid any city other than 0 at slot 0 ...
        for i in range(1, N):
            u = idx(i, 0, N)
            add(u, u, start_weight)
        # ... and reward city 0 at slot 0.
        u0 = idx(0, 0, N)
        add(u0, u0, -start_weight)

    return Q, N


# ============================
# Route decoding, validation and cost
# ============================
def decode_route(sample, N):
    """Decode a QUBO sample into a list of cities ordered by tour position.

    ``sample`` is a mapping from variable index to 0/1; missing keys count as
    0. For each position, the lowest-numbered city whose variable is 1 is
    taken. A position with no such city decodes to -1 so that broken samples
    stay visible for debugging. The result is not checked for being a
    permutation; use ``is_valid_route`` for that.
    """
    return [
        next((i for i in range(N) if sample.get(idx(i, t, N), 0) == 1), -1)
        for t in range(N)
    ]


def is_valid_route(route, N):
    """Return True when ``route`` visits each city 0..N-1 exactly once."""
    return sorted(route) == list(range(N))


def compute_cost(route, D):
    """Return the cyclic length of ``route``, including the edge back to the start.

    Raises:
        ValueError: if any city index is outside ``0..len(D)-1``. Without this
            check, a -1 placeholder from ``decode_route`` would silently read
            the last row/column of ``D`` through negative indexing.
    """
    n_cities = len(D)
    bad = [city for city in route if not 0 <= city < n_cities]
    if bad:
        raise ValueError(f"route contains invalid city indices {bad}: {list(route)}")
    total = 0
    for k, a in enumerate(route):
        b = route[(k + 1) % len(route)]
        total += D[a, b]
    return total


# ============================
# Simple Gaussian KDE
# ============================
def kde_1d(samples, num_points=200):
    """Evaluate a Gaussian kernel density estimate of ``samples``.

    The density is evaluated on ``num_points`` evenly spaced points between
    ``min(samples)`` and ``max(samples)``. The bandwidth follows Silverman's
    rule of thumb, ``1.06 * std * n**(-1/5)``. A zero standard deviation is
    replaced by 1.

    Returns:
        ``(xs, ys)`` as NumPy arrays. Both are empty for empty input. For a
        single sample, ``ys`` is all zeros.
    """
    data = np.asarray(samples, dtype=float)
    n = data.size
    if n == 0:
        return np.array([]), np.array([])
    xs = np.linspace(data.min(), data.max(), num_points)
    if n < 2:
        return xs, np.zeros_like(xs)

    std = data.std()
    if std == 0:
        std = 1.0
    h = 1.06 * std * n ** (-1 / 5)  # Silverman's rule; > 0 since std > 0

    # Broadcast to a (num_points, n) grid of standardized distances.
    z = (xs[:, None] - data[None, :]) / h
    ys = np.exp(-0.5 * z * z).sum(axis=1) / (n * h * math.sqrt(2.0 * math.pi))
    return xs, ys


# ============================
# Route diversity heatmap data
# ============================
def route_diversity_matrix(routes, N):
    """Return an (N, N) matrix of city-at-position frequencies.

    Entry ``[city, t]`` is the fraction of ``routes`` that place ``city`` at
    position ``t``. City values outside 0..N-1 (e.g. the -1 placeholder) are
    ignored. An empty ``routes`` gives an all-zero matrix.
    """
    freq = np.zeros((N, N), dtype=float)
    for route in routes:
        for t, city in enumerate(route):
            if 0 <= city < N:
                freq[city, t] += 1
    if len(routes) > 0:
        freq /= len(routes)
    return freq


# ============================
# Run SA or SQA repeatedly
# ============================
def _best_feasible_route(result, D, N):
    """Return ``(route, cost)`` for the lowest-energy valid tour in ``result``, or None."""
    # NOTE(review): assumes ``result.samples()`` yields samples in increasing
    # energy order (the dimod SampleSet default). Confirm for the installed
    # openjij version.
    for sample in result.samples():
        route = decode_route(sample, N)
        if is_valid_route(route, N):
            return route, compute_cost(route, D)
    return None


def run_algorithm(name, sampler, Q, D, N, num_runs=20, num_reads=20):
    """Sample ``Q`` ``num_runs`` times and keep each run's best valid tour.

    Args:
        name: label used in the printed progress lines.
        sampler: object whose ``sample_qubo(Q, num_reads=...)`` returns a
            dimod-style sample set (``.samples()``, ``.first.sample``), as the
            OpenJij samplers used in ``main`` do.
        Q, D, N: the QUBO, distance matrix and number of cities.
        num_runs: number of independent ``sample_qubo`` calls.
        num_reads: reads requested per call.

    Returns:
        ``(all_costs, all_routes)`` for the runs that produced a valid tour;
        routes are tuples. Runs in which no read decodes to a valid tour are
        reported and skipped, so both lists may be shorter than ``num_runs``.
    """
    all_costs = []
    all_routes = []

    print(f"\n=== Running {name} ===")
    for r in range(num_runs):
        result = sampler.sample_qubo(Q, num_reads=num_reads)
        best = _best_feasible_route(result, D, N)
        if best is None:
            fallback = decode_route(result.first.sample, N)
            print(f"{name} Run {r+1:02d}: no feasible tour in {num_reads} reads "
                  f"(lowest-energy decode={fallback}), skipped")
            continue
        route, cost = best
        all_costs.append(cost)
        all_routes.append(tuple(route))
        print(f"{name} Run {r+1:02d}: Route={route}, Cost={cost}")

    return all_costs, all_routes


# ============================
# Reporting and plots
# ============================
def _summarize(label, costs, routes):
    """Print min/max/mean/std of ``costs`` and the most frequent route."""
    if not costs:
        print(f"{label}: no feasible tours")
        return
    print(f"{label}: min={min(costs)}, max={max(costs)}, "
          f"mean={np.mean(costs):.2f}, std={np.std(costs):.2f}")
    route, count = Counter(routes).most_common(1)[0]
    print(f"{label}: distinct routes={len(set(routes))}, "
          f"most common={list(route)} (x{count})")


def _plot_cost_distribution(sa_costs, sqa_costs):
    """Figure 1: density-normalised cost histograms with Gaussian KDE overlays."""
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(8, 5))
    combined = list(sa_costs) + list(sqa_costs)
    # Unit-width bins aligned to integers. The extra edge gives the maximum
    # cost its own bin.
    bins = np.arange(math.floor(min(combined)), math.ceil(max(combined)) + 2)
    ax.hist(sa_costs, bins=bins, alpha=0.4, label="SA", density=True)
    ax.hist(sqa_costs, bins=bins, alpha=0.4, label="SQA", density=True)
    for costs, label in ((sa_costs, "SA KDE"), (sqa_costs, "SQA KDE")):
        xs, ys = kde_1d(costs)
        ax.plot(xs, ys, label=label)
    ax.set_title("Tour Cost Distribution (SA vs SQA)")
    ax.set_xlabel("Tour Cost")
    ax.set_ylabel("Density")
    ax.legend()
    fig.tight_layout()
    fig.savefig("tsp_cost_hist_kde_sa_vs_sqa.png")
    return fig


def _plot_violin_box(sa_costs, sqa_costs):
    """Figure 2: violin plot and box plot of tour costs side by side."""
    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    data = [sa_costs, sqa_costs]
    axs[0].violinplot(data, positions=[1, 2], showmeans=True)
    axs[0].set_title("Violin Plot of Tour Costs")
    # boxplot places boxes at 1..len(data) by default. Tick labels are set
    # explicitly rather than via boxplot(labels=...), which matplotlib 3.9
    # deprecated in favour of tick_labels.
    axs[1].boxplot(data, showmeans=True)
    axs[1].set_title("Box Plot of Tour Costs")
    for ax in axs:
        ax.set_xticks([1, 2])
        ax.set_xticklabels(["SA", "SQA"])
    fig.tight_layout()
    fig.savefig("tsp_cost_violin_box_sa_vs_sqa.png")
    return fig


def _plot_route_diversity(sa_mat, sqa_mat):
    """Figure 3: city-by-position frequency heatmaps for SA and SQA."""
    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    for ax, mat, label in zip(axs, (sa_mat, sqa_mat), ("SA", "SQA")):
        im = ax.imshow(mat, aspect='auto', origin='lower')
        ax.set_title(f"Route Diversity ({label})")
        ax.set_xlabel("Position t")
        ax.set_ylabel("City index")
        fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.savefig("tsp_route_diversity_heatmap_sa_vs_sqa.png")
    return fig


# ============================
# Main
# ============================
def main():
    """Run SQA and SA on the module-level TSP instance and plot the comparison."""
    # Imported here rather than at module level, so the helpers above stay
    # importable without openjij / matplotlib installed.
    import matplotlib.pyplot as plt
    import openjij as oj

    Q, n_cities = build_qubo(D, penalty, fix_start=True)

    sqa_sampler = oj.SQASampler()
    sa_sampler = oj.SASampler()

    # Adjust the number of runs and the reads per run as needed.
    num_runs = 20
    num_reads = 20

    sqa_costs, sqa_routes = run_algorithm("SQA", sqa_sampler, Q, D, n_cities, num_runs, num_reads)
    sa_costs, sa_routes = run_algorithm("SA", sa_sampler, Q, D, n_cities, num_runs, num_reads)

    print("\n=== Summary ===")
    _summarize("SQA", sqa_costs, sqa_routes)
    _summarize("SA ", sa_costs, sa_routes)

    if sa_costs and sqa_costs:
        _plot_cost_distribution(sa_costs, sqa_costs)
        _plot_violin_box(sa_costs, sqa_costs)
    else:
        print("Skipping cost plots: at least one sampler produced no feasible tours.")

    _plot_route_diversity(
        route_diversity_matrix(sa_routes, n_cities),
        route_diversity_matrix(sqa_routes, n_cities),
    )
    plt.show()


if __name__ == "__main__":
    main()