import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from scipy.sparse import coo_matrix, csr_matrix
from scipy.sparse.linalg import spsolve


# nav.: http://xkcd.com/356/
# versie 2, met np.linalg.solve, we stoppen het probleem in Ax=b vergelijking
# versie 3, we gaan sparse matrices gebruiken van scipy.sparse

sx,sy = 71, 71             # grootte van het grid
p1    = (sx/2-1, sy/2)     # punt van 1 volt
p2    = (sx/2+1, sy/2+1)   # punt van 0 volt


N = sx * sy # dus... hoofdletter n is sx maal sy...

# array's waar we later een matrix van maken
A_row  = []
A_col  = []
A_data = []
b_row  = []
b_col  = []
b_data = []


# wat hulp functies...
def V(x,y):
    '''Maps x,y to index for A and b'''
    return y * sy + x
def W(i):
    '''Maps index for A and b to x,y'''  # wordt niet gebruikt denk ik
    return i % sy, i / sy
def buren(x,y):
    '''Geeft een lijst met tuples van 'valid' buren terug'''
    retval = []
    if (x > 0   ): retval.append((x-1,y  ))
    if (x < sx-1): retval.append((x+1,y  ))
    if (y > 0   ): retval.append((x  ,y-1))
    if (y < sy-1): retval.append((x  ,y+1))
    return retval

# Matrix A en kolom b vullen, 
for x, y in product(range(sx), range(sy)): # itereren over alle punten

    if (x, y) == p1:     # 1 Volt punt
    
        A_row.append(V(x,y))
        A_col.append(V(x,y))
        A_data.append(1.0)
        b_row.append(V(x,y)) 
        b_col.append(0)
        b_data.append(1.0) 
        
    elif (x, y) == p2:    # 0 Volt punt
    
        A_row.append(V(x,y))
        A_col.append(V(x,y))
        A_data.append(1.0)
        
    else:               # de andere 'gewone' punten inclusief randen en hoeken
    
        lb = buren(x, y) # lijst van buren
        A_row.append(V(x,y))
        A_col.append(V(x,y))
        A_data.append(float(len(lb)))
        for buur in lb:
            A_row.append(V(x,y))
            A_col.append(V(*buur))
            A_data.append(-1.0)


# Matrices maken
A = coo_matrix((A_data, (A_row, A_col)), shape=(N,N)).tocsr()
b = coo_matrix((b_data, (b_row, b_col)), shape=(N,1)).tocsr()

# Klaar is kees, oh nee nog even oplossen...
v = spsolve(A,b)

# ... rechtzetten ....
v = v.reshape(sx, sy).T

# weerstand uitrekenen
weerstand = 1.0 / (v[p1] - v[p1[0]+1,p1[1]  ] +
                   v[p1] - v[p1[0]-1,p1[1]  ] +
                   v[p1] - v[p1[0]  ,p1[1]+1] +
                   v[p1] - v[p1[0]  ,p1[1]-1]  )
print weerstand

# plaatjes maken
plt.close('all')
plt.figure()
plt.subplot(111, aspect='equal')
plt.contourf(v.T, 100)
plt.grid()
plt.title('Weerstand: ' + str(weerstand) + ' $\Omega$')
plt.show()