import numpy as np
import matplotlib.pyplot as plt

# nav.: http://xkcd.com/356/

sx,sy = 71,71              # grootte van het grid
p1    = (sx/2-1, sy/2)     # punt van 1 volt
p2    = (sx/2+1, sy/2+1)   # punt van 0 volt
Niter = 1000               # aantal iteraties

v = np.zeros((sx,sy)) # spannings veld
t = np.zeros((sx,sy)) # tijdelijk veld

v += 0.5  # even wat sneller...

def iteratie_slag():
    global v,t,sx,sy,p1,p2
    """
    Spanning op een punt zal uiteindelijk het gemiddelde van de buren worden.
    """
    # in het veld
    for x in range(1,sx-1):
        for y in range(1,sy-1):
            t[x,y] = (v[x-1,y] + v[x+1,y] + v[x,y-1] + v[x,y+1]) / 4.0
    # op de randen
    for x in range(1,sx-1):
        t[x,0]    = (v[x-1,0   ] + v[x,1   ] + v[x+1,0   ]) / 3.0
        t[x,sy-1] = (v[x-1,sy-1] + v[x,sy-2] + v[x+1,sy-1]) / 3.0
    for y in range(1,sy-1):
        t[0,y]    = (v[0   ,y-1] + v[1   ,y] + v[0   ,y+1]) / 3.0
        t[sx-1,y] = (v[sx-1,y-1] + v[sx-2,y] + v[sx-1,y+1]) / 3.0
    # op de hoeken
    t[0   ,0   ] = (v[1   ,0   ] + v[0   ,1   ]) / 2.0 
    t[sx-1,0   ] = (v[sx-2,0   ] + v[sx-1,1   ]) / 2.0 
    t[0   ,sy-1] = (v[1   ,sy-1] + v[0   ,sy-2]) / 2.0 
    t[sx-1,sy-1] = (v[sx-2,sy-1] + v[sx-1,sy-2]) / 2.0 
    # boundary condition
    t[p1] = 1.0
    t[p2] = 0.0
    # omwisselen
    v[:,:] = t[:,:]

# tbv, een snellere iterartie slag, wat hulp arrays voor fancy indexing
buren_x = np.zeros((sx, sy, 4), dtype=np.int)
buren_y = np.zeros((sx, sy, 4), dtype=np.int)
for x in range(sx):
    for y in range(sy):
        buren_x[x, y] = [x-1, x+1, x  , x  ]
        buren_y[x, y] = [y  , y  , y-1, y+1]
        # we gaan smokkelen als, een index buiten het veld valt nemen we 
        # het punt zelf.
        for b in range(4): # buren
            if buren_x[x, y, b] < 0    : buren_x[x, y, b] = 0
            if buren_x[x, y, b] > sx-1 : buren_x[x, y, b] = sx-1
            if buren_y[x, y, b] < 0    : buren_y[x, y, b] = 0
            if buren_y[x, y, b] > sy-1 : buren_y[x, y, b] = sy-1

def iteratie_slag2():
    global v,buren_x,buren_y,p1,p2
    """
    snellere versie, pure numpy
    """
    v = np.sum(v[buren_x, buren_y], axis=2) / 4.0
    # boundary condition
    v[p1] = 1.0
    v[p2] = 0.0
  

# itereren
r = []
for i in range(Niter):
    #iteratie_slag()
    iteratie_slag2()
    weerstand = 1.0 / (v[p1] - v[p1[0]+1,p1[1]  ] +
                       v[p1] - v[p1[0]-1,p1[1]  ] +
                       v[p1] - v[p1[0]  ,p1[1]+1] +
                       v[p1] - v[p1[0]  ,p1[1]-1]  )
    r.append(weerstand)
print weerstand

# plaatjes maken
plt.close('all')
plt.figure()
plt.subplot(122, aspect='equal')
plt.contourf(v.T, 100)
plt.grid()
plt.title('Potentiaal')

plt.subplot(121)
plt.plot(r,'s-', markevery=len(r)/10)
plt.grid()
plt.xlabel('Iteraties')
plt.ylabel('Weerstand [$\Omega$]')
plt.title('eind: %f $\Omega$' % weerstand)

plt.show()