正如我们之前所说...
下面有 4 个文件:sm.h、main.c、sm.c 和 gen.c。接口在 sm.h 中定义,main.c 演示如何使用该接口,sm.c 是实现,gen.c 是为 main.c 生成测试输入文件的 C 程序。
我建议仔细研究 sm.h 和 main.c。谢谢。
接口定义在 sm.h 中:
// Public interface of a sparse matrix of doubles keyed by (row, col).
// All three types are opaque to callers; only pointers to them are used.
// (The guard is SM_H rather than __SM_H__: identifiers containing a double
// underscore are reserved for the implementation.)
#ifndef SM_H
#define SM_H
typedef struct sm_entry_s sm_entry_s;
typedef struct sm_search_s sm_search_s;
typedef struct sm_hash_s sm_hash_s;
// Creates an empty sparse matrix.
// estimate: hint for the initial number of hash buckets; the implementation
//           rounds it up to a power of two and grows the table as needed.
// Returns the new matrix; release it with sparse_matrix_free().
// The process is terminated if memory cannot be allocated.
sm_hash_s *sparse_matrix_create(int estimate) ;
// Adds value to the cell at row,col.
// If an entry exists at row,col, value is added to the entry.
// If an entry does not exist, one is created and set to value.
// Returns the cell's value after the update.
double sparse_matrix_put(sm_hash_s *sm, unsigned row, unsigned col, double value) ;
// Returns the current value at row,col. If there is no entry, returns 0.0.
double sparse_matrix_get(sm_hash_s *sm, unsigned row, unsigned col) ;
// Frees the matrix and every entry stored in it. Search cursors obtained
// from sparse_matrix_search() are NOT freed by this call.
void sparse_matrix_free(sm_hash_s *sm) ;
// Prints allocation and bucket-occupancy statistics to stdout.
void sparse_matrix_stat(sm_hash_s *sm) ;
// Dumps every non-empty bucket to stdout ( ugly, meant for debugging ).
void sparse_matrix_dump(sm_hash_s *sm) ;
// Returns a cursor for use with sparse_entry ( see below ).
// The cursor is heap-allocated; the caller releases it with free().
sm_search_s *sparse_matrix_search(sm_hash_s *sm) ;
// Gets the next entry in the sparse matrix search and stores its
// coordinates and value through prow, pcol and pval.
// Returns 1 when an entry was stored, 0 when there are no more entries
// ( the output parameters are left untouched in that case ).
// Entries are visited in bucket order, not in row/column order.
int sparse_entry(sm_search_s *p, unsigned *prow, unsigned *pcol, double *pval) ;
#endif
如何使用该接口的示例在 main.c 中:
#include <stdio.h>
#include <stdlib.h>
#include "sm.h"
// Reads triples of the form %u %u %lf ( row, col, value ) from stdin,
// accumulates them in a sparse matrix, prints the table statistics and
// then prints one line per stored entry.
int main(int argc, char **argv) {
    sm_hash_s *matrix = sparse_matrix_create(32);
    unsigned r, c;
    double v;

    // Stop at end of input or at the first line that does not parse.
    for (;;) {
        if (scanf("%u %u %lf", &r, &c, &v) != 3) {
            break;
        }
        sparse_matrix_put(matrix, r, c, v);
    }

    // Summary of allocation and bucket usage.
    sparse_matrix_stat(matrix);

    // For a raw view of the buckets, call sparse_matrix_dump(matrix) here.

    // Walk every entry through a search cursor.
    sm_search_s *cursor = sparse_matrix_search(matrix);
    while (sparse_entry(cursor, &r, &c, &v) != 0) {
        printf("%u %u %lf\n", r, c, v);
    }

    // The cursor is owned by the caller and is released separately.
    free(cursor);
    sparse_matrix_free(matrix);
    return 0;
}
实现在 sm.c 中:
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <string.h>
#include "sm.h"
// Marker for functions that are part of the public interface; expands to
// nothing and exists only to contrast with `static`.
#define public /* nothing */
// Debug tracing with printf-style arguments. Change the 0 to 1 for debugging.
// Wrapped in do { } while (0) so that the macro behaves as one statement:
// the previous `if(0) printf` form would capture a following `else`.
#define debug(...) do { if (0) printf(__VA_ARGS__); } while (0)
// One stored matrix cell. Entries that fall into the same hash bucket are
// linked into a singly linked chain through `next`.
struct sm_entry_s {
sm_entry_s *next; // next entry in the same bucket chain, 0 at the end
unsigned row; // row coordinate of the cell
unsigned col; // column coordinate of the cell
unsigned index; // bucket index (masked hash) recorded by sparse_matrix_put at insertion
double value; // accumulated value stored at (row, col)
};
// The sparse matrix itself: a chained hash table of sm_entry_s plus some
// byte counters reported by sparse_matrix_stat.
struct sm_hash_s {
sm_entry_s **buckets; // array of nbucket chain heads
unsigned nentry; // incremented each time sparse_matrix_put creates an entry
unsigned nbucket; /* hash table bucket count */
unsigned tableMask; /* hash table mask: nbucket - 1, set by sm_init */
unsigned tableCapacity; /* growth threshold: 14*nbucket/16, set by sm_init */
unsigned alloced; // bytes handed out through sm_alloc (plus the struct itself)
unsigned freeed; // bytes reported as released through sm_free
};
// Iteration cursor returned by sparse_matrix_search and advanced by
// sparse_entry.
struct sm_search_s {
sm_hash_s *sm; // matrix being traversed
int i; // bucket index that `next` was taken from
sm_entry_s *next; // entry the next sparse_entry call will report, 0 when done
};
// Allocates n zero-filled bytes.
// sm: matrix whose `alloced` byte counter is increased by n; may be null
//     ( used while the matrix itself is being created ).
// Never returns null: on allocation failure a message is written to stderr
// and the process exits with status 1.
static void *sm_alloc(sm_hash_s *sm, unsigned n) {
    void *p = calloc(1, n);
    if (p == NULL) {
        // n is unsigned, so it has to be printed with %u ( it was %d ).
        fprintf(stderr, "calloc(%u) failed\n", n);
        exit(1);
    }
    if (sm) {
        sm->alloced += n;
    }
    return p;
}
// Releases the block f and records its size n ( in bytes ) in the matrix's
// freed-byte counter. sm must not be null.
static void sm_free(sm_hash_s *sm, void *f, unsigned n) {
    sm->freeed = sm->freeed + n;
    free(f);
}
/// Round up to next higher power of 2 (return x if it's already a power
/// of 2).
/// Returns 0 when no valid answer exists: for x <= 0, and for x above 2^30
/// (the largest power of two an int can hold).
static int pow2roundup (int x) {
    assert(sizeof(x) == 4);
    if (x <= 0)
        return 0;
    // Anything above 2^30 would round to 2^31, which overflows a signed int.
    if (x > 0x40000000)
        return 0;
    // Work in unsigned arithmetic so every shift and the final +1 are well
    // defined. Smearing the top set bit of x-1 into all lower positions
    // gives 2^k - 1; adding one yields the power of two.
    unsigned v = (unsigned)x - 1u;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    return (int)(v + 1u);
}
// Allocates a zeroed bucket array for the current sm->nbucket and derives
// the index mask and the growth threshold from that bucket count.
// sm->nbucket must already be set by the caller.
static void sm_init(sm_hash_s *sm) {
    const unsigned count = sm->nbucket;

    sm->buckets = sm_alloc(sm, count * sizeof(*sm->buckets));
    // The mask selects the low bits of a hash, which works because the
    // callers keep the bucket count a power of two.
    sm->tableMask = count - 1;
    // Grow once the table is roughly 14/16 full.
    sm->tableCapacity = 14 * count / 16;
}
// Integer mixing function applied to the row coordinate: two rounds of
// xor-shift followed by a multiply, then a final xor-shift.
static unsigned int hash_row(unsigned int x) {
    const unsigned int multiplier = 0x45d9f3b;

    for (int round = 0; round < 2; round++) {
        x ^= x >> 16;
        x *= multiplier;
    }
    x ^= x >> 16;
    return x;
}
// Integer mixing function applied to the column coordinate. Same shape as
// the row hash but with a different multiplier.
static unsigned int hash_col(unsigned int x) {
    const unsigned int multiplier = 0x3335b369;

    for (int round = 0; round < 2; round++) {
        x ^= x >> 16;
        x *= multiplier;
    }
    x ^= x >> 16;
    return x;
}
// Maps (row, col) to a bucket index of table p: the two coordinate hashes
// are xor-ed together and reduced with the table's current mask.
static unsigned int hash_index(sm_hash_s *p, unsigned row, unsigned col) {
    const unsigned row_bits = hash_row(row);
    const unsigned col_bits = hash_col(col);

    return (row_bits ^ col_bits) & p->tableMask;
}
// Doubles the number of buckets and moves every entry into the new bucket
// array. The old array is released afterwards; the entries themselves are
// relinked, not reallocated.
static void rehash(sm_hash_s *sm) {
    sm_entry_s **oldtable = sm->buckets;
    unsigned oldsize = sm->nbucket;

    sm->nbucket *= 2;
    debug("rehash new size = %u\n", sm->nbucket);
    // Allocates the larger array and updates tableMask / tableCapacity.
    sm_init(sm);

    for (unsigned i = 0; i < oldsize; i++) {
        sm_entry_s *next;
        for (sm_entry_s *ent = oldtable[i]; ent; ent = next) {
            next = ent->next;
            // The bucket must be recomputed from the coordinates. The
            // stored ent->index was already masked with the OLD table mask,
            // so reusing it would leave every entry in the lower half of
            // the new table, while lookups ( which use the new mask ) may
            // search the upper half and miss the entry.
            unsigned index = hash_index(sm, ent->row, ent->col);
            ent->index = index;
            ent->next = sm->buckets[index];
            sm->buckets[index] = ent;
        }
    }
    sm_free(sm, oldtable, oldsize * sizeof(sm_entry_s *));
}
// Creates an empty sparse matrix.
// nbucket: requested initial bucket count; it is rounded up to a power of
//          two. Values that cannot be rounded ( zero, negative, or too
//          large ) fall back to a single bucket, and the table then grows
//          on demand.
// Returns the new matrix. Allocation failure terminates the process
// ( see sm_alloc ).
public sm_hash_s *sparse_matrix_create(int nbucket) {
    assert(sizeof(unsigned) == 4);
    sm_hash_s *p = sm_alloc(NULL, sizeof(*p));
    // The struct itself was allocated before there was a matrix to charge
    // it to, so account for it by hand.
    p->alloced = sizeof(*p);

    // A bucket count of 0 would give a mask of 0xFFFFFFFF over an empty
    // array, so never start with fewer than one bucket.
    int rounded = pow2roundup(nbucket);
    if (rounded < 1) {
        rounded = 1;
    }
    p->nbucket = (unsigned)rounded;
    sm_init(p);
    return p;
}
// Searches the chain of bucket `index` for the entry stored at (row, col).
// Returns the entry, or 0 when the chain does not contain it.
static sm_entry_s *sm_find(sm_hash_s *sm, unsigned index, unsigned row, unsigned col) {
    sm_entry_s *cur = sm->buckets[index];

    while (cur != 0) {
        if (cur->row == row && cur->col == col) {
            break;
        }
        cur = cur->next;
    }
    return cur;
}
// Adds `value` to the cell at (row, col), creating the cell with an initial
// value of 0 when it does not exist yet. Returns the cell's value after the
// addition.
public double sparse_matrix_put(sm_hash_s *sm, unsigned row, unsigned col, double value) {
    const unsigned slot = hash_index(sm, row, col);
    sm_entry_s *entry = sm_find(sm, slot, row, col);

    // Existing cell: just accumulate.
    if (entry != 0) {
        entry->value += value;
        return entry->value;
    }

    // New cell: build it and push it onto the front of its bucket chain.
    entry = sm_alloc(sm, sizeof(*entry));
    entry->row = row;
    entry->col = col;
    entry->value = 0;
    entry->index = slot;
    entry->next = sm->buckets[slot];
    sm->buckets[slot] = entry;

    // The count is compared before being incremented; grow the table once
    // it has passed the capacity threshold.
    if (sm->nentry++ > sm->tableCapacity) {
        rehash(sm);
    }

    entry->value += value;
    return entry->value;
}
// Looks up the cell at (row, col) and returns its value; cells that were
// never stored read as 0.0.
public double sparse_matrix_get(sm_hash_s *sm, unsigned row, unsigned col) {
    const unsigned slot = hash_index(sm, row, col);
    sm_entry_s *entry = sm_find(sm, slot, row, col);

    if (entry == 0) {
        return 0.0;
    }
    return entry->value;
}
// Frees every entry, the bucket array and the matrix itself.
// Passing a null pointer is a no-op, mirroring free().
// Search cursors created with sparse_matrix_search are not touched.
public void sparse_matrix_free(sm_hash_s *sm) {
    if (sm == NULL) {
        return;
    }
    for (unsigned i = 0; i < sm->nbucket; i++) {
        sm_entry_s *ent = sm->buckets[i];
        while (ent) {
            sm_entry_s *next = ent->next;
            // Account for the size of the entry itself; the previous
            // sizeof(*pp) was the size of a pointer.
            sm_free(sm, ent, sizeof(*ent));
            ent = next;
        }
        sm->buckets[i] = NULL;
    }
    free(sm->buckets);
    free(sm);
}
// Prints two lines to stdout: the byte counters ( allocated, freed, in
// use ) and the table shape ( bucket count, number of entries, longest
// chain, and entries per bucket as an integer quotient ).
public void sparse_matrix_stat(sm_hash_s *sm) {
    unsigned total = 0;
    unsigned longest = 0;

    for (unsigned b = 0; b < sm->nbucket; b++) {
        // Length of the chain hanging off bucket b.
        int chain = 0;
        sm_entry_s *cur = sm->buckets[b];
        while (cur) {
            chain++;
            cur = cur->next;
        }
        total += chain;
        if (chain > longest) {
            longest = chain;
        }
    }

    unsigned average = total / sm->nbucket;
    unsigned in_use = sm->alloced - sm->freeed;

    printf("%u alloc, %u free, %u in use\n",
           sm->alloced, sm->freeed, in_use);
    printf("%u buckets, %u entries, %u max, %u avg\n",
           sm->nbucket, total, longest, average);
}
// Debugging aid: prints one line per non-empty bucket to stdout, listing
// the bucket index followed by each entry's row, column, value and address.
public void sparse_matrix_dump(sm_hash_s *sm) {
    // The index is unsigned so that it matches both nbucket and the %u
    // conversion below.
    for (unsigned i = 0; i < sm->nbucket; i++) {
        sm_entry_s *p = sm->buckets[i];
        if (p == NULL) {
            continue;
        }
        printf("[%u] ", i);
        for ( ; p; p = p->next) {
            printf(" [%u %u %lf]", p->row, p->col, p->value);
            // %p requires a void * argument.
            printf(" %p", (void *)p);
        }
        printf("\n");
    }
}
// Positions cursor p on the head of the first non-empty bucket at index i
// or later. Returns 1 when such a bucket exists; otherwise clears p->next
// and returns 0.
static int search_next(sm_search_s *p, int i) {
    sm_hash_s *table = p->sm;

    while (i < table->nbucket) {
        sm_entry_s *head = table->buckets[i];
        if (head) {
            p->i = i;
            p->next = head;
            return 1;
        }
        i++;
    }

    // Ran off the end of the table: the traversal is finished.
    p->next = 0;
    return 0;
}
// Creates a cursor positioned on the first entry of sm ( if any ) for use
// with sparse_entry.
// The cursor is allocated with malloc and is owned by the caller, who
// releases it with free(). On allocation failure a message is written to
// stderr and the process exits with status 1.
public sm_search_s *sparse_matrix_search(sm_hash_s *sm) {
    sm_search_s *p = malloc(sizeof(*p));
    if (p == NULL) {
        // sizeof yields a size_t, which is printed with %zu ( it was %ld ).
        fprintf(stderr, "malloc(%zu) failed\n", sizeof(*p));
        exit(1);
    }
    p->sm = sm;
    p->i = 0;
    p->next = NULL;
    // Advance to the first non-empty bucket; leaves p->next null when the
    // matrix is empty.
    search_next(p, 0);
    return p;
}
// Reports the entry the cursor currently points at through prow, pcol and
// pval, then advances the cursor. Returns 1 when an entry was reported and
// 0 ( leaving the outputs untouched ) once the traversal is exhausted.
public int sparse_entry(sm_search_s *p, unsigned *prow, unsigned *pcol, double *pval) {
    sm_entry_s *cur = p->next;

    if (cur == 0) {
        return 0;
    }

    *prow = cur->row;
    *pcol = cur->col;
    *pval = cur->value;

    // Stay in the same chain while it lasts; otherwise move on to the next
    // non-empty bucket.
    if (cur->next == 0) {
        search_next(p, p->i + 1);
    } else {
        p->next = cur->next;
    }
    return 1;
}
测试输入生成程序在 gen.c 中:
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <assert.h>
#include <ctype.h>
// Returns a pseudo-random value in the closed range [min, max], drawn from
// random(). Requires min <= max.
// The value is obtained by reducing random() modulo the range size, so the
// distribution carries the usual modulo bias, and ranges wider than what
// random() can return are not fully covered.
unsigned get_rand(unsigned min, unsigned max) {
    assert(min <= max);
    // Number of values in the range. This wraps to 0 only when the range is
    // the whole of unsigned, which previously caused a modulo by zero.
    unsigned span = max - min + 1u;
    unsigned long r = (unsigned long)random();
    unsigned u;
    if (span == 0) {
        u = (unsigned)r;
    } else {
        u = min + (unsigned)(r % span);
    }
    assert(u >= min);
    assert(u <= max);
    return u;
}
// Test-input generator. Usage: <prog> N C
// Writes C lines of the form "row col 1.0" to stdout, with row and col
// drawn at random from 0 .. N-1. Exits with a failure status when the
// arguments are missing, do not start with a digit, or N is out of range.
int main(int argc, char **argv) {
    // isdigit requires a value representable as unsigned char ( or EOF ),
    // so plain char must be converted before the call.
    if(argc != 3 || !isdigit((unsigned char)*argv[1]) || !isdigit((unsigned char)*argv[2])) {
        fprintf(stderr, "usage: %s %%u %%u\n", argv[0]);
        exit(EXIT_FAILURE);
    }
    unsigned long N = strtoul(argv[1], 0, 0);
    unsigned long C = strtoul(argv[2], 0, 0);
    // N-1 is the largest coordinate and is passed to get_rand as an
    // unsigned: N == 0 would wrap around, and anything larger than the
    // unsigned range would be silently truncated.
    if(N == 0 || N - 1 > (unsigned)-1) {
        fprintf(stderr, "%s: N must be between 1 and %lu\n",
                argv[0], (unsigned long)(unsigned)-1 + 1ul);
        exit(EXIT_FAILURE);
    }
    unsigned top = (unsigned)(N - 1);
    srandom((unsigned)time(0));
    // The counter has the same type as C so that the comparison stays
    // correct for counts above INT_MAX.
    for(unsigned long i = 0; i < C; i++) {
        unsigned long row = get_rand(0, top);
        unsigned long col = get_rand(0, top);
        printf("%lu %lu 1.0\n", row, col);
    }
    return 0;
}