]> git.sesse.net Git - wloh/blob - train.pl
Fix include paths for static scripts.
[wloh] / train.pl
1 #! /usr/bin/perl
2 use DBI;
3 use strict;
4 use warnings;
5 no warnings qw(once);
6 use POSIX;
7 use lib qw( include );
8 require 'config.pm';
9
10 # Find last completely done season 
11 sub find_last_season {
12         my ($dbh, $locale) = @_;
13         my $ref = $dbh->selectrow_hashref('SELECT sesong FROM fotballserier se JOIN fotballspraak sp ON se.spraak=sp.id GROUP BY kultur,sesong HAVING COUNT(*)=COUNT(avgjort=1 OR NULL) AND kultur=? ORDER BY kultur,sesong DESC LIMIT 1', undef, $locale);
14         return $ref->{'sesong'};
15 }
16
17 sub fetch_games {
18         my ($dbh, $locale, $last_season, $games, $ids) = @_;
19         my $q = $dbh->prepare('
20 SELECT
21   deltager1.id as p1, deltager2.id as p2, maalfor, maalmot, least(pow(2.0, (sesong - ? + 3) / 3.0), 1.0) AS vekt
22 FROM
23   ( SELECT * FROM fotballresultater UNION ALL SELECT * FROM fotballresultater_2123 ) resultater
24   JOIN Fotballdeltagere deltager1 ON resultater.Lagrecno=deltager1.Nr AND resultater.Serie=deltager1.Serie
25   JOIN Fotballdeltagere deltager2 ON resultater.Motstander=deltager2.Nr AND resultater.Serie=deltager2.Serie
26   JOIN Fotballserier serier ON resultater.Serie=serier.Nr
27   JOIN Fotballspraak spraak ON serier.Spraak=spraak.Id
28 WHERE deltager1.Nr > deltager2.nr AND kultur=?
29         ');
30         $q->execute($last_season, $locale);
31
32         while (my $ref = $q->fetchrow_hashref) {
33                 next if ($ref->{'maalfor'} == 150 && $ref->{'maalmot'} == 0);
34                 next if ($ref->{'maalfor'} == 0 && $ref->{'maalmot'} == 150);
35                 next if ($ref->{'maalfor'} == 150 && $ref->{'maalmot'} == 150);
36                 next if ($ref->{'maalfor'} == 0 && $ref->{'maalmot'} == 0);
37                 push @$games, { %$ref };
38                 $ids->{$ref->{'p1'}} = 1;
39                 $ids->{$ref->{'p2'}} = 1;
40         }
41 }
42
43 sub output_to_file {
44         my ($games, $ids) = @_;
45
46         my $tmpnam = POSIX::tmpnam();
47         open DATA, ">", $tmpnam
48                 or die "$tmpnam: $!";
49
50         printf DATA "%d\n", scalar keys %$ids;
51         for my $id (keys %$ids) {
52                 printf DATA "%d\n", $id;
53         }
54         for my $ref (@$games) {
55                 printf DATA "%d %d %d %d %f\n", $ref->{'p1'}, $ref->{'p2'}, $ref->{'maalfor'}, $ref->{'maalmot'}, $ref->{'vekt'};
56         }
57         close DATA;
58
59         return $tmpnam;
60 }
61
62 sub train_model {
63         my ($filename, $locale, $ratings, $covariances, $aux_params) = @_;
64
65         open RATINGS, "$config::base_dir/bayeswf < $filename |"
66                 or die "bayeswf: $!";
67         while (<RATINGS>) {
68                 chomp;
69                 my @x = split;
70                 if ($x[0] eq 'covariance') {
71                         push @$covariances, (join("\t", @x[1..3]));
72                 } elsif ($x[0] eq 'aux_param') {
73                         push @$aux_params, ($locale .  "\t" . $x[1] . "\t" . $x[2]);
74                 } else {
75                         push @$ratings, ($x[2] . "\t" . $x[0] . "\t" . $x[1]);
76                 }
77         }
78
79         close RATINGS;
80 }
81
82 sub find_all_locales {
83         my $dbh = shift;
84         my $q = $dbh->prepare('SELECT kultur FROM fotballspraak WHERE nyestesesong<>-1');
85         $q->execute;
86
87         my @locales = ();
88         while (my $ref = $q->fetchrow_hashref) {
89                 push @locales, $ref->{'kultur'};
90         }
91
92         return @locales;
93 }
94
95 my $dbh = DBI->connect($config::local_connstr, $config::local_username, $config::local_password)
96         or die "connect: " . $DBI::errstr;
97 $dbh->{AutoCommit} = 0;
98 $dbh->{RaiseError} = 1;
99
100 $dbh->do('SET client_min_messages TO WARNING');
101
102 my @locales = find_all_locales($dbh);
103
104 my @ratings = ();
105 my @covariances = ();
106 my @aux_params = ();
107
108 for my $locale (@locales) {
109         my $last_season = find_last_season($dbh, $locale);
110         my @games = ();
111         my %ids = ();
112         fetch_games($dbh, $locale, $last_season, \@games, \%ids);
113         my $tmpnam = output_to_file(\@games, \%ids);
114
115         train_model($tmpnam, $locale, \@ratings, \@covariances, \@aux_params);
116         unlink($tmpnam);
117 }
118
119 $dbh->do('CREATE TABLE new_covariance ( player1 smallint NOT NULL, player2 smallint NOT NULL, cov float NOT NULL )');
120 $dbh->do('COPY new_covariance ( player1, player2, cov ) FROM STDIN');
121 $dbh->pg_putcopydata(join("\n", @covariances));
122 $dbh->pg_putcopyend();
123 $dbh->do('ALTER TABLE new_covariance ADD PRIMARY KEY ( player1, player2 );');
124 $dbh->do('DROP TABLE IF EXISTS covariance');
125 $dbh->do('ALTER TABLE new_covariance RENAME TO covariance');
126
127 $dbh->do('TRUNCATE aux_params');
128 $dbh->do('COPY aux_params ( kultur, id, value ) FROM STDIN');
129 $dbh->pg_putcopydata(join("\n", @aux_params));
130 $dbh->pg_putcopyend();
131
132 $dbh->do('TRUNCATE ratings');
133 $dbh->do('COPY ratings ( id, rating, rating_stddev ) FROM STDIN');
134 $dbh->pg_putcopydata(join("\n", @ratings));
135 $dbh->pg_putcopyend();
136
137 $dbh->commit;