linalg / app.py
buoyrina
matrix-vector
99f0fa7
raw
history blame
3.24 kB
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
def _parse_matrix(text):
    """Parse matrix text such as ``"a,b;c,d"`` into a float array.

    Rows are separated by ``;`` and entries by ``,``. Raises ``ValueError`` on
    non-numeric entries; the shape is validated by the caller.
    """
    return np.array([[float(x) for x in row.split(",")] for row in text.split(";")])


def _parse_vector(text):
    """Parse vector text such as ``"x,y"`` into a 1-D float array.

    Raises ``ValueError`` on non-numeric entries; the shape is validated by the caller.
    """
    return np.array([float(x) for x in text.split(",")])


def _render_plot(matrix, vector, transformed_vector, grid_size=10):
    """Draw the [-1, 1] x [-1, 1] grid before/after ``matrix`` plus both vectors.

    Args:
        matrix: 2x2 float array applied to the grid.
        vector: The original 2-D vector (drawn in red).
        transformed_vector: ``matrix @ vector`` (drawn in blue).
        grid_size: Number of grid lines per axis.

    Returns:
        The rendered plot as a fully-loaded PIL image.
    """
    coords = np.linspace(-1, 1, grid_size)
    grid_x, grid_y = np.meshgrid(coords, coords)  # each (grid_size, grid_size)
    points = np.vstack([grid_x.ravel(), grid_y.ravel()])  # (2, grid_size**2)
    moved = matrix @ points
    # Row-major reshape restores the mesh layout: moved_x[j, i] / moved_y[j, i]
    # is the image of the grid point (grid_x[j, i], grid_y[j, i]).
    moved_x, moved_y = moved.reshape(2, *grid_x.shape)

    # Keep the original [-2, 2] window, but widen it when the transformed grid
    # or either vector would otherwise be drawn outside the visible area.
    extent = max(np.abs(moved).max(), np.abs(vector).max(), np.abs(transformed_vector).max())
    limit = max(2.0, 1.1 * float(extent))

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        for k in range(grid_size):
            # Row k is a horizontal grid line, column k a vertical one.
            ax.plot(grid_x[k, :], grid_y[k, :], color="lightgray", linewidth=0.5, linestyle="--")
            ax.plot(grid_x[:, k], grid_y[:, k], color="lightgray", linewidth=0.5, linestyle="--")
            ax.plot(moved_x[k, :], moved_y[k, :], color="gray", linewidth=0.5, alpha=0.7)
            ax.plot(moved_x[:, k], moved_y[:, k], color="gray", linewidth=0.5, alpha=0.7)
        ax.quiver(0, 0, vector[0], vector[1], angles="xy", scale_units="xy", scale=1,
                  color="red", label="Original Vector")
        ax.quiver(0, 0, transformed_vector[0], transformed_vector[1], angles="xy",
                  scale_units="xy", scale=1, color="blue", label="Transformed Vector")
        ax.axhline(0, color="black", linewidth=0.5)
        ax.axvline(0, color="black", linewidth=0.5)
        ax.set_xlim(-limit, limit)
        ax.set_ylim(-limit, limit)
        ax.set_aspect("equal")
        ax.grid(True)
        ax.legend()
        ax.set_title("Matrix-Vector Multiplication Visualization")
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Close even on failure so repeated requests do not accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    image = Image.open(buf)
    # Image.open is lazy; decode now so any failure is raised inside the caller's try.
    image.load()
    return image


def matrix_vector_multiplication_visualization(matrix, vector):
    """Multiply a 2x2 matrix by a 2-D vector and render the result for the UI.

    Args:
        matrix: Matrix text such as ``"1,0;0,1"`` (rows separated by ``;``,
            entries by ``,``).
        vector: Vector text such as ``"1,1"``.

    Returns:
        A ``(message, image)`` pair. On success ``message`` is
        ``"Transformed Vector: [x, y]"`` and ``image`` is a PIL image of the
        plot. On any failure ``message`` starts with ``"Error: "`` and ``image``
        is ``None``.
    """
    try:
        matrix = _parse_matrix(matrix)
        vector = _parse_vector(vector)
        if matrix.shape != (2, 2):
            return "Error: Matrix must be 2x2.", None
        if vector.shape != (2,):
            return "Error: Vector must be 2D.", None
        # inf/nan would otherwise fail later inside matplotlib's axis-limit code.
        if not (np.isfinite(matrix).all() and np.isfinite(vector).all()):
            return "Error: All entries must be finite numbers.", None
        transformed_vector = matrix @ vector
        image = _render_plot(matrix, vector, transformed_vector)
        return f"Transformed Vector: {transformed_vector.tolist()}", image
    except Exception as e:
        # UI callback boundary: report the failure to the user instead of crashing.
        return f"Error: {e}", None
# Create the Gradio app: two text inputs (matrix, vector), a text result and an
# image output, wired to matrix_vector_multiplication_visualization via the button.
with gr.Blocks() as app:
    gr.Markdown("## Matrix-Vector Multiplication Visualization")
    gr.Markdown("""
- Enter a **2x2 matrix** as `a,b;c,d` (rows separated by semicolons).
- Enter a **2D vector** as `x,y`.
- See the original vector (red), transformed vector (blue), and grid transformation.
""")
    with gr.Row():
        matrix_input = gr.Textbox(label="Matrix (2x2, e.g., 1,0;0,1)", placeholder="e.g., 1,0;0,1")
        vector_input = gr.Textbox(label="Vector (2D, e.g., 1,1)", placeholder="e.g., 1,1")
    output_text = gr.Textbox(label="Result")
    output_image = gr.Image(label="Visualization")
    calculate_button = gr.Button("Visualize")
    calculate_button.click(
        fn=matrix_vector_multiplication_visualization,
        inputs=[matrix_input, vector_input],
        outputs=[output_text, output_image],
    )

if __name__ == "__main__":
    # Launch only when run as a script (e.g. `python app.py`) so that importing
    # this module, for example from tests, does not start a server.
    app.launch()