# Source code for symforce.examples.robot_3d_localization.plotting

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from matplotlib.widgets import Slider

from symforce.opt.optimizer import Optimizer
from symforce.python_util import AttrDict
from symforce.values import Values


def plot_solution(
    optimizer: Optimizer,
    result: Optimizer.Result,
    animated: bool = False,
    show_iteration_text: bool = False,
) -> None:
    """
    Visualize the optimization problem along its iterations.

    If animated is True, every iteration is rendered into an animation that is saved to
    ``robot_3d_localization.mp4`` (the path is printed). Otherwise an interactive window is shown
    with a slider for choosing which iteration to display.

    Args:
        optimizer: The optimizer that produced ``result``; its ``load_iteration_values`` is used
            to decode the values stored for each iteration.
        result: Optimization result. ``result.optimized_values`` provides the initial drawing,
            and each entry of ``result.iterations`` is one selectable iteration.
        animated: Save an mp4 animation instead of showing the interactive slider.
        show_iteration_text: Draw the iteration number and error in the plot. When enabled, a
            rejected iteration (other than iteration 0) only updates this text and leaves the
            previously drawn geometry unchanged.
    """
    # Pull out values from the result
    values_per_iter = [
        optimizer.load_iteration_values(stats.values) for stats in result.iterations
    ]

    # Create the layout
    fig = plt.figure()
    ax = plt.axes(projection="3d")
    ax.set_zlim3d(0, 11)
    ax.set_ylim3d(0, 11)
    ax.set_xlim3d(0, 11)
    plt.grid(True, alpha=0.5)
    plt.tight_layout()

    # Style for the dotted lines from poses to their landmark heading measurements
    dotted_style = dict(
        arrow_length_ratio=0.05,
        linestyle=":",
        linewidth=0.4,
        color="red",
        alpha=0.8,
        capstyle="butt",
    )

    # The helpers below are shared by the initial drawing and update_plot, so both always
    # produce identically styled artists.

    def draw_poses(plot_data: AttrDict):
        """Scatter a circle at every pose position and return the scatter artist."""
        return ax.scatter3D(
            plot_data.world_t_body[:, 0],
            plot_data.world_t_body[:, 1],
            plot_data.world_t_body[:, 2],
            color="skyblue",
            zorder=3,
            s=100,
        )

    def draw_pose_line(plot_data: AttrDict):
        """Draw one line through all pose positions, in order, and return its line artist."""
        return ax.plot3D(
            plot_data.world_t_body[:, 0],
            plot_data.world_t_body[:, 1],
            plot_data.world_t_body[:, 2],
            color="black",
            zorder=2,
            alpha=0.8,
        )[0]

    def draw_meas_arrows(plot_data: AttrDict, landmark_inx: int):
        """Draw, from every pose position, an arrow along its measurement vector to one landmark."""
        return ax.quiver(
            plot_data.world_t_body[:, 0],
            plot_data.world_t_body[:, 1],
            plot_data.world_t_body[:, 2],
            plot_data.meas_vectors[:, landmark_inx, 0],
            plot_data.meas_vectors[:, landmark_inx, 1],
            plot_data.meas_vectors[:, landmark_inx, 2],
            **dotted_style,
        )

    # Pull out quantities to plot
    data = get_data_to_plot(result.optimized_values)

    # Draw a circle at the origin
    ax.scatter3D([0], [0], [0], color="black", s=50, zorder=2)

    # Draw landmark locations
    ax.scatter3D(
        data.world_t_landmark[:, 0],
        data.world_t_landmark[:, 1],
        data.world_t_landmark[:, 2],
        color="orange",
        s=100,
        zorder=3,
    )

    # Poses, the line connecting them, and one set of measurement arrows per landmark. The
    # single-element lists let update_plot swap in new artists from inside its closure.
    poses_circles = [draw_poses(data)]
    poses_lines = [draw_pose_line(data)]
    meas_arrows = [
        draw_meas_arrows(data, landmark_inx)
        for landmark_inx in range(data.meas_vectors.shape[1])
    ]

    # Text box to write iteration stats
    if show_iteration_text:
        text = ax.text(8, 7, 9, "-", color="black")

    def update_plot(slider_value: np.float64) -> None:
        """
        Update the plot using the given iteration.

        Args:
            slider_value: Index of the iteration to show; truncated to an int.
        """
        num = int(slider_value)

        # Set iteration text and abort if we rejected this iteration
        if show_iteration_text:
            stats = result.iterations[num]
            if num > 0 and not stats.update_accepted:
                text.set_text(f"Iteration: {num} (rejected)\nError: {stats.new_error:.1f}")
                return
            text.set_text(f"Iteration: {num}\nError: {stats.new_error:.1f}")

        # Get plottable data for this iteration
        iter_data = get_data_to_plot(values_per_iter[num])

        # Remove the previous artists and redraw them from this iteration's data
        poses_circles[0].remove()
        poses_circles[0] = draw_poses(iter_data)

        poses_lines[0].remove()
        poses_lines[0] = draw_pose_line(iter_data)

        for landmark_inx in range(len(meas_arrows)):
            meas_arrows[landmark_inx].remove()
            meas_arrows[landmark_inx] = draw_meas_arrows(iter_data, landmark_inx)

    if animated:
        ani = animation.FuncAnimation(
            fig, update_plot, len(values_per_iter), fargs=tuple(), interval=500
        )
        filename = "robot_3d_localization.mp4"
        ani.save(filename, dpi=200)
        print(f"Wrote to {filename}")
    else:
        # Add a slider for iterations at the bottom of the plot
        plt.subplots_adjust(bottom=0.2)
        ax_slider = plt.axes([0.25, 0.1, 0.65, 0.03])
        iteration_slider = Slider(
            ax=ax_slider,
            label="Iteration",
            valmin=0,
            valmax=len(values_per_iter) - 1,
            valinit=len(values_per_iter) - 1,
            valfmt="%0.0f",
            # Snap to whole iterations. Without this the "%0.0f" label rounds (2.6 -> "3")
            # while update_plot truncates with int() (2.6 -> iteration 2), so the label could
            # name a different iteration than the one drawn.
            valstep=1,
        )
        iteration_slider.on_changed(update_plot)
        iteration_slider.set_val(len(values_per_iter) - 1)

        plt.show()


def get_data_to_plot(v: Values) -> AttrDict:
    """
    Convert optimization values into the numpy arrays used by the plotting code.

    Args:
        v: Values containing ``world_t_landmark``, ``world_T_body`` and
            ``body_t_landmark_measurements`` entries.

    Returns:
        AttrDict with:
            world_t_landmark: ``np.array`` of ``v["world_t_landmark"]``.
            world_t_body: the ``position()`` of every pose, one row per pose.
            pose_x_axes / pose_y_axes / pose_z_axes: each pose's ``rotation()`` applied to the
                unit x / y / z vector, one row per pose.
            meas_vectors: entry ``[i, j]`` is ``v["body_t_landmark_measurements"][i][j]``
                multiplied by the rotation of pose ``i``.
    """
    data = AttrDict()

    # Landmark positions
    data.world_t_landmark = np.array(v["world_t_landmark"])

    poses = v["world_T_body"]

    # Pose positions
    data.world_t_body = np.array([pose.position() for pose in poses])

    def rotate_per_pose(axis: list) -> np.ndarray:
        # Apply each pose's rotation to the same fixed vector; one row per pose
        return np.array([pose.rotation() * np.array(axis) for pose in poses])

    # Pose axis vectors
    data.pose_x_axes = rotate_per_pose([1, 0, 0])
    data.pose_y_axes = rotate_per_pose([0, 1, 0])
    data.pose_z_axes = rotate_per_pose([0, 0, 1])

    # Measurement heading vectors from each pose (outer index) to each landmark (inner index)
    data.meas_vectors = np.array(
        [
            [
                pose.rotation() * v["body_t_landmark_measurements"][pose_inx][landmark_inx]
                for landmark_inx in range(len(v["world_t_landmark"]))
            ]
            for pose_inx, pose in enumerate(poses)
        ]
    )

    return data