166 Commits

Author SHA1 Message Date
e0451d026c Update index. 2020-12-28 22:32:24 -08:00
1f1345477f Use Hugo's plaintext instead of file path for Stork index. 2020-12-28 22:32:17 -08:00
44529e872f Fix wrong path name for index file. 2020-12-28 22:23:29 -08:00
a10996954e Redirect search index. 2020-12-28 22:22:37 -08:00
4d1dfb5f66 Generate initial index. This will not be static indefinitely; I just need to find a way to build it in Nix. 2020-12-28 22:22:19 -08:00
f97b624688 Tweak search styles a little bit. 2020-12-28 22:03:04 -08:00
8215c59122 Change search highlight color. 2020-12-27 22:24:43 -08:00
eb97bd9c3e Add search box to main page. 2020-12-27 20:09:05 -08:00
d2e100fe4b Add search CSS. 2020-12-27 20:08:40 -08:00
de09a1f6bd Enable TOML output. 2020-12-27 20:08:27 -08:00
c40672e762 Add a way to generate TOML template for Stork. 2020-12-27 20:08:18 -08:00
565d4a6955 Update resume. 2020-12-14 18:03:31 -08:00
8f0f2eb35e Finish up the Coq Advent of Code post. 2020-12-02 18:45:28 -08:00
234b795157 Add Coq advent of code post. 2020-12-02 01:14:32 -08:00
e317c56c99 Add some shortcodes for making the game theory post nicer. 2020-11-08 21:22:51 -08:00
29d12a9914 Publish new Idris post. 2020-11-02 01:08:41 -08:00
b459e9cbfe Update typesafe imperative language post draft. 2020-11-01 23:56:55 -08:00
52abe73ef7 Make the typesafe imperative language work properly. 2020-10-31 01:34:23 -07:00
f0fe481bcf Add post about the typesafe imperative language. 2020-10-30 19:07:30 -07:00
222446a937 Add non-color indication to highlighted lines. 2020-10-10 17:12:40 -07:00
e7edd43034 Add draft warning. 2020-09-27 16:22:29 -07:00
2bc2c282e1 Revert "Experimentally enable shortcodes"
This reverts commit 5cc92d3a9d.
2020-09-27 14:47:25 -07:00
5cc92d3a9d Experimentally enable shortcodes 2020-09-27 14:42:35 -07:00
4be8a25699 Add a label to codelines that includes the source file. 2020-09-27 14:41:56 -07:00
d3421733e1 Update resume. 2020-09-25 22:52:07 -07:00
4c099a54e8 Publish part 13. 2020-09-19 16:27:41 -07:00
9f77f07ed2 Finish 13th part of the compiler series. 2020-09-19 16:14:07 -07:00
04ab1a137c Mark 13th post as draft 2020-09-19 11:59:54 -07:00
53744ac772 Fix wording 2020-09-18 15:14:34 -07:00
50a1c33adb Adjust code lines. 2020-09-18 14:42:50 -07:00
d153af5212 Get rid of more constructors and make mangled names optional. 2020-09-18 14:09:03 -07:00
a336b27b6c Remove unneeded explicit calls to std::string 2020-09-18 12:27:57 -07:00
97eb4b6e3e Fix silent error in set_mangled_name 2020-09-18 12:02:37 -07:00
430768eac5 Add a TODO to part 13. 2020-09-17 22:56:08 -07:00
5db864881a Fix use of wrong environment for name mangling. 2020-09-17 22:55:27 -07:00
d3b1047d37 Renamed the file since we have no optimization. 2020-09-17 22:36:43 -07:00
98cac103c4 Update blog post, switching away from two sections. 2020-09-17 22:35:40 -07:00
7226d66f67 Remove the parent method from type_env. 2020-09-17 22:35:12 -07:00
8a352ed3ea Roll back optimization changes. 2020-09-17 20:45:24 -07:00
02f8306c7b Use an instruction instead of a special-case boolean instruction. 2020-09-17 18:33:52 -07:00
cf6f353f20 Change tagging to assume sign extension.
ARM and x86_64 require "real" pointers to be
sign-extended in their top bits. This means
a working pointer is guaranteed to have either "11"
as leading bits, or "00". So, to tag a "fake" pointer
which is an unboxed 32-bit integer, we simply toggle
the leading bit.
2020-09-17 18:30:55 -07:00
7a631b3557 Make a few more things classes. 2020-09-17 18:30:41 -07:00
5e13047846 Make global scope a class. 2020-09-15 19:45:05 -07:00
c17d532802 Make type_mgr a class. 2020-09-15 19:19:58 -07:00
55e4e61906 Make mangler a class and reformat graph. 2020-09-15 19:13:48 -07:00
f2f88ab9ca Make env a class. 2020-09-15 19:12:12 -07:00
ba418d357f Make type_env a class. 2020-09-15 19:10:36 -07:00
0e3f16139d Make llvm_context a class. 2020-09-15 19:08:00 -07:00
55486d511f Make some refactors for name mangling and encapsulation. 2020-09-15 18:51:28 -07:00
6080094c41 Require mangled names for global variables. 2020-09-15 14:39:31 -07:00
6b8d3b0f8a Refactor errors and update post draft. 2020-09-11 21:29:49 -07:00
725958137a Factor type into case strategy constructor. 2020-09-11 13:03:00 -07:00
1f6b4bef74 Start working on part 13 of compiler series. 2020-09-11 02:16:57 -07:00
fe1e0a6de0 Switch to using FILE* and default YY_INPUT. 2020-09-11 02:16:29 -07:00
1f3c42fc44 Change constructor visibility to global.
Constructors are always effectively global.
2020-09-10 20:11:55 -07:00
8bf67c7dc3 Merge branch 'master' of https://dev.danilafe.com/Web-Projects/blog-static into master 2020-09-10 18:47:55 -07:00
13214cee96 Try out unboxing integers. 2020-09-10 17:32:16 -07:00
579c7bad92 Enable more syntax. 2020-09-10 16:04:44 -07:00
f00a6a7783 Actually use the environment for binop functions. 2020-09-10 16:03:56 -07:00
2a81fdd9fb Stop using mangled names for local variables. 2020-09-10 15:14:19 -07:00
17c59e595c Add assertion regarding local name mangling. 2020-09-10 15:05:02 -07:00
ad2576eae2 Move common code into loops. 2020-09-10 14:50:03 -07:00
72d8179cc5 Add compile-time flag to disable output. 2020-09-10 14:07:28 -07:00
dbabec0db6 Tweak parsed type error warning. 2020-09-10 14:04:06 -07:00
76675fbc9b Make make_case_for throw from the second time on.
Also clean up the errors thrown a little bit.
2020-09-10 14:03:04 -07:00
ca395b5c09 Add programs to trigger error cases. 2020-09-10 14:02:19 -07:00
1a05d5ff7a Add type errors to identifier nodes. 2020-09-10 12:59:26 -07:00
56f0dbd02f Prevent case compilation from crashing and burning. 2020-09-10 12:53:55 -07:00
9fc0ff961d Add more built-in boolean-specific instructions. 2020-09-10 12:44:41 -07:00
73441dc93b Register booleans as internal types. 2020-09-10 00:54:35 -07:00
df5f5eba1c Make sure to delete LLVM target machine. 2020-09-09 23:45:48 -07:00
d950b8dc90 Initialize graph indegree. 2020-09-09 23:44:53 -07:00
85394b185d Add prototype impl of case specialization.
Boolean cases could be translated to ifs, and
integer cases to jumps. That's still in progress.
2020-09-09 22:49:35 -07:00
86b49f9cc3 Add 'internal' types. 2020-09-09 18:08:38 -07:00
9769b3e396 Replace throw 0 with real exceptions or assertions. 2020-09-09 17:19:23 -07:00
e337992410 Add sources for unification type errors. 2020-09-09 15:26:18 -07:00
d5c3a44041 Add extra line after code fence. 2020-09-09 15:25:48 -07:00
eade42be49 Print locations in non-unification type errors. 2020-09-09 15:15:25 -07:00
d0fac50cfd Add locations to patterns. 2020-09-09 15:15:09 -07:00
dd4aa6fb9d Require C++17 for optionals 2020-09-09 15:14:37 -07:00
aa867b2e5f Add locations to error reporting. 2020-09-09 15:08:43 -07:00
2fa2be4b9e Add a method to print location. 2020-09-09 14:41:16 -07:00
d5536467f6 Touch up source index code. 2020-09-09 14:20:10 -07:00
67cb61c93f Keep track of locations in definitions. 2020-09-09 14:19:46 -07:00
578d580683 Make driver keep track of line numbers and locations. 2020-09-09 13:57:01 -07:00
789f277780 Update ASTs to actually take in locations.
Didn't realize I broke the build by leaving this out.
2020-09-09 13:29:28 -07:00
308ec615b9 Start using driver, and switch to file IO. 2020-09-09 13:28:43 -07:00
0e40c9e216 Enable locations. 2020-09-09 12:21:50 -07:00
5dbf75b5e4 Fork off version 13 of the compiler. 2020-09-08 18:38:05 -07:00
b921ddfc8d Update resume. 2020-09-02 13:47:55 -07:00
bf3c81fe24 Fix invalid property for flexbox. 2020-08-29 00:08:16 -07:00
06cbd93f05 Publish boolean values post. 2020-08-21 23:06:26 -07:00
6c3780d9ea Finish up the draft of the boolean values post. 2020-08-21 17:37:22 -07:00
6f0667bb28 Add draft of boolean values post. 2020-08-20 21:19:47 -07:00
8368283a3e Add warning about evaluation model. 2020-08-15 01:37:57 -07:00
18ee3a1526 Add margins to code tables. 2020-08-15 01:18:01 -07:00
b0e501f086 Publish the new typesafe interpreter post. 2020-08-12 15:48:53 -07:00
385ae59133 Merge branch 'colors' into master 2020-08-12 15:43:42 -07:00
49469bdf12 Fix issues in typesafe interpreter article. 2020-08-12 15:43:22 -07:00
020417e971 Add draft of new Idris typechecking post.
This one uses line highlights!
2020-08-12 01:38:38 -07:00
eff0de5330 Allow the codelines shortcode to use hl_lines. 2020-08-12 01:37:55 -07:00
b219f6855e Change highlight color for code. 2020-08-12 01:37:39 -07:00
068d0218b0 Fix typesafe interpreter post. 2020-08-11 19:54:45 -07:00
65215ccdd6 Start working on improving color handling in code. 2020-08-11 19:29:55 -07:00
3e9f6a14f2 Fix single-line scroll bug 2020-08-11 17:43:59 -07:00
7623787b1c Mention Kai's help in time traveling article. 2020-07-30 02:05:43 -07:00
e15daa8f6d Make the detailed time traveling example a subsection. 2020-07-30 01:09:30 -07:00
298cf6599c Publish time traveling post. 2020-07-30 00:58:48 -07:00
841930a8ef Add time traveling code. 2020-07-30 00:57:47 -07:00
9b37e496cb Add figure size classes to global CSS. 2020-07-30 00:57:27 -07:00
58e6ad9e79 Update lazy evaluation post with images and more. 2020-07-30 00:49:35 -07:00
3aa2a6783e Add images to time traveling post. 2020-07-29 20:09:32 -07:00
d64a0d1fcd Add version of typesafe interpreter with tuples. 2020-07-23 16:38:54 -07:00
ba141031dd Remove the tweet shortcode. 2020-07-23 13:50:09 -07:00
ebdc63f5a0 Make small edit to DELL post. 2020-07-23 13:45:24 -07:00
5af0a09714 Publish DELL post. 2020-07-23 13:41:33 -07:00
8a2bc2660c Update date on typesafe interpreter. 2020-07-22 14:38:01 -07:00
e59b8cf403 Edit and publish typesafe interpreter. 2020-07-22 14:35:19 -07:00
b078ef9a22 Remove implicit arguments from TypsafeIntrV2. 2020-07-22 14:30:47 -07:00
fdaec6d5a9 Make small adjustments to backend math post. 2020-07-21 15:34:46 -07:00
b631346379 Publish the mathematics post. 2020-07-21 14:55:52 -07:00
e9f2378b47 Resume working on the draft of time traveling. 2020-07-20 22:32:14 -07:00
7d2f78d25c Add links and make small clarifications. 2020-07-20 13:56:07 -07:00
1f734a613c Add the second part of the typechecking post. 2020-07-19 22:56:44 -07:00
a3c299b057 Start working on the improved type-safe interpreter. 2020-07-19 17:16:31 -07:00
12aedfce92 Make small fixes to math rendering code. 2020-07-19 14:09:24 -07:00
65645346a2 Adjust title in DELL post. 2020-07-18 20:47:38 -07:00
cb65e89e53 Add math rendering draft. 2020-07-18 20:47:16 -07:00
6a2fec8ef4 Update the about page. 2020-07-17 19:39:43 -07:00
aa59c90810 Add the draft of the DELL post. 2020-07-17 19:39:35 -07:00
2b317930a0 Add resume link. 2020-07-15 15:09:37 -07:00
e7d56dd4bd Clean up some styles. 2020-07-15 13:56:03 -07:00
a4fedb276d Adjust margin spacing. 2020-07-15 13:18:34 -07:00
277c0a2ce6 Rework sidenote spacing and TOC. 2020-07-15 13:13:47 -07:00
ef3c61e9e6 Make table of contents dark. 2020-06-30 22:15:22 -07:00
1908126607 Add border to code. 2020-06-30 21:31:16 -07:00
2d77f8489f Move hiding code into margin SCSS. 2020-06-30 21:22:19 -07:00
0371651fdd Fix headings on Starbound post. 2020-06-24 23:01:35 -07:00
01734d24f7 Get started on tables of contents. 2020-06-24 22:46:22 -07:00
71fc0546e0 Move move code into common 'margin node' mixin. 2020-06-24 22:06:08 -07:00
871a745702 Extract margin variables and mixins into separate file. 2020-06-24 14:21:56 -07:00
3f0df8ae0d Add links for 12th part of compiler series. 2020-06-21 22:21:43 -07:00
1746011c16 Publish 12th part of compiler series. 2020-06-21 00:51:04 -07:00
7c4cfbf3d4 Fix typechecking of mutually recursive functions. 2020-06-21 00:47:26 -07:00
8524e098a8 Make proofreading-based fixes. 2020-06-20 23:50:26 -07:00
971f58da9b Finish draft of part 12 of compiler series. 2020-06-20 22:03:57 -07:00
c496be1031 Finish implementation description in part 12. 2020-06-20 20:46:54 -07:00
21851e3a9c Add more content to part 12. 2020-06-19 02:22:08 -07:00
600d5b91ea Remove unneeded parent class. 2020-06-18 23:06:13 -07:00
09b90c3bbc Add line numbers to codelines shortode. 2020-06-18 22:30:01 -07:00
f6ca13d6dc Add more implementation content to part 12. 2020-06-18 22:29:38 -07:00
9c4d7c514f Add more content to post 12 draft. 2020-06-16 23:32:09 -07:00
ad1946e9fb Add first draft of lambdas. 2020-06-14 02:00:20 -07:00
68910458e8 Properly handle null types in pattern typechecking. 2020-06-14 00:43:39 -07:00
240e87eca4 Use mangled names in variable environments. 2020-06-13 23:43:52 -07:00
6b5f7e25b7 Maybe finish the let/in code? 2020-06-01 00:23:41 -07:00
e7229e644f Start working on translation. 2020-05-31 18:52:52 -07:00
08c8aca144 Start working on a lifted version of a definition. 2020-05-31 14:37:33 -07:00
7f8dae74ac Adjust type output. 2020-05-31 00:50:58 -07:00
08503116ff Mark some definitions as global, so as not to capture them. 2020-05-31 00:34:12 -07:00
a1d679a59d No longer destroy the list of free variables.
It so happens that this list will tell us which variables
need to be captured.
2020-05-30 23:29:36 -07:00
4586bd0188 Check for free variables in the environment before generalizing. 2020-05-30 16:40:27 -07:00
a97b50f497 Add parsing of let/in. 2020-05-28 14:44:12 -07:00
c84ff11d0d Add typechecking to let/in expressions. 2020-05-26 00:52:54 -07:00
e966e74487 Extract ordering functionality into definition group. 2020-05-25 23:58:56 -07:00
3865abfb4d Add a struct to contain groups of mutually recursive definitions. 2020-05-25 22:11:45 -07:00
160 changed files with 10670 additions and 397 deletions

View File

@@ -0,0 +1,11 @@
@import "variables.scss";
@import "mixins.scss";
.assumption-number {
font-weight: bold;
}
.assumption {
@include bordered-block;
padding: 0.8rem;
}

102
code/aoc-coq/day1.v Normal file
View File

@@ -0,0 +1,102 @@
Require Import Coq.Lists.List.
Require Import Omega.
Definition has_pair (t : nat) (is : list nat) : Prop :=
exists n1 n2 : nat, n1 <> n2 /\ In n1 is /\ In n2 is /\ n1 + n2 = t.
Fixpoint find_matching (is : list nat) (total : nat) (x : nat) : option nat :=
match is with
| nil => None
| cons y ys =>
if Nat.eqb (x + y) total
then Some y
else find_matching ys total x
end.
Fixpoint find_sum (is : list nat) (total : nat) : option (nat * nat) :=
match is with
| nil => None
| cons x xs =>
match find_matching xs total x with
| None => find_sum xs total (* Was buggy! *)
| Some y => Some (x, y)
end
end.
Lemma find_matching_correct : forall is k x y,
find_matching is k x = Some y -> x + y = k.
Proof.
intros is. induction is;
intros k x y Hev.
- simpl in Hev. inversion Hev.
- simpl in Hev. destruct (Nat.eqb (x+a) k) eqn:Heq.
+ injection Hev as H; subst.
apply EqNat.beq_nat_eq. auto.
+ apply IHis. assumption.
Qed.
Lemma find_matching_skip : forall k x y i is,
find_matching is k x = Some y -> find_matching (cons i is) k x = Some y.
Proof.
intros k x y i is Hsmall.
simpl. destruct (Nat.eqb (x+i) k) eqn:Heq.
- apply find_matching_correct in Hsmall.
symmetry in Heq. apply EqNat.beq_nat_eq in Heq.
assert (i = y). { omega. } rewrite H. reflexivity.
- assumption.
Qed.
Lemma find_matching_works : forall is k x y, In y is /\ x + y = k ->
find_matching is k x = Some y.
Proof.
intros is. induction is;
intros k x y [Hin Heq].
- inversion Hin.
- inversion Hin.
+ subst a. simpl. Search Nat.eqb.
destruct (Nat.eqb_spec (x+y) k).
* reflexivity.
* exfalso. apply n. assumption.
+ apply find_matching_skip. apply IHis.
split; assumption.
Qed.
Theorem find_sum_works :
forall k is, has_pair k is ->
exists x y, (find_sum is k = Some (x, y) /\ x + y = k).
Proof.
intros k is. generalize dependent k.
induction is; intros k [x' [y' [Hneq [Hinx [Hiny Hsum]]]]].
- (* is is empty. But x is in is! *)
inversion Hinx.
- (* is is not empty. *)
inversion Hinx.
+ (* x is the first element. *)
subst a. inversion Hiny.
* (* y is also the first element; but this is impossible! *)
exfalso. apply Hneq. apply H.
* (* y is somewhere in the rest of the list.
We've proven that we will find it! *)
exists x'. simpl.
erewrite find_matching_works.
{ exists y'. split. reflexivity. assumption. }
{ split; assumption. }
+ (* x is not the first element. *)
inversion Hiny.
* (* y is the first element,
so x is somewhere in the rest of the list.
Again, we've proven that we can find it. *)
subst a. exists y'. simpl.
erewrite find_matching_works.
{ exists x'. split. reflexivity. rewrite plus_comm. assumption. }
{ split. assumption. rewrite plus_comm. assumption. }
* (* y is not the first element, either.
Of course, there could be another matching pair
starting with a. Otherwise, the inductive hypothesis applies. *)
simpl. destruct (find_matching is k a) eqn:Hf.
{ exists a. exists n. split.
reflexivity.
apply find_matching_correct with is. assumption. }
{ apply IHis. unfold has_pair. exists x'. exists y'.
repeat split; assumption. }
Qed.

View File

@@ -32,6 +32,7 @@ add_executable(compiler
binop.cpp binop.hpp binop.cpp binop.hpp
instruction.cpp instruction.hpp instruction.cpp instruction.hpp
graph.cpp graph.hpp graph.cpp graph.hpp
global_scope.cpp global_scope.hpp
${BISON_parser_OUTPUTS} ${BISON_parser_OUTPUTS}
${FLEX_scanner_OUTPUTS} ${FLEX_scanner_OUTPUTS}
main.cpp main.cpp

View File

@@ -3,6 +3,7 @@
#include "binop.hpp" #include "binop.hpp"
#include "error.hpp" #include "error.hpp"
#include "type_env.hpp" #include "type_env.hpp"
#include "env.hpp"
static void print_indent(int n, std::ostream& to) { static void print_indent(int n, std::ostream& to) {
while(n--) to << " "; while(n--) to << " ";
@@ -13,14 +14,19 @@ void ast_int::print(int indent, std::ostream& to) const {
to << "INT: " << value << std::endl; to << "INT: " << value << std::endl;
} }
void ast_int::find_free(type_mgr& mgr, type_env_ptr& env, std::set<std::string>& into) { void ast_int::find_free(std::set<std::string>& into) {
this->env = env;
} }
type_ptr ast_int::typecheck(type_mgr& mgr) { type_ptr ast_int::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
return type_ptr(new type_app(env->lookup_type("Int"))); return type_ptr(new type_app(env->lookup_type("Int")));
} }
void ast_int::translate(global_scope& scope) {
}
void ast_int::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_int::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr(new instruction_pushint(value))); into.push_back(instruction_ptr(new instruction_pushint(value)));
} }
@@ -30,20 +36,25 @@ void ast_lid::print(int indent, std::ostream& to) const {
to << "LID: " << id << std::endl; to << "LID: " << id << std::endl;
} }
void ast_lid::find_free(type_mgr& mgr, type_env_ptr& env, std::set<std::string>& into) { void ast_lid::find_free(std::set<std::string>& into) {
this->env = env; into.insert(id);
if(env->lookup(id) == nullptr) into.insert(id);
} }
type_ptr ast_lid::typecheck(type_mgr& mgr) { type_ptr ast_lid::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
return env->lookup(id)->instantiate(mgr); return env->lookup(id)->instantiate(mgr);
} }
void ast_lid::translate(global_scope& scope) {
}
void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
auto mangled_name = this->env->get_mangled_name(id);
into.push_back(instruction_ptr( into.push_back(instruction_ptr(
env->has_variable(id) ? (env->has_variable(mangled_name) && !this->env->is_global(id)) ?
(instruction*) new instruction_push(env->get_offset(id)) : (instruction*) new instruction_push(env->get_offset(mangled_name)) :
(instruction*) new instruction_pushglobal(id))); (instruction*) new instruction_pushglobal(mangled_name)));
} }
void ast_uid::print(int indent, std::ostream& to) const { void ast_uid::print(int indent, std::ostream& to) const {
@@ -51,16 +62,22 @@ void ast_uid::print(int indent, std::ostream& to) const {
to << "UID: " << id << std::endl; to << "UID: " << id << std::endl;
} }
void ast_uid::find_free(type_mgr& mgr, type_env_ptr& env, std::set<std::string>& into) { void ast_uid::find_free(std::set<std::string>& into) {
this->env = env;
} }
type_ptr ast_uid::typecheck(type_mgr& mgr) { type_ptr ast_uid::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
return env->lookup(id)->instantiate(mgr); return env->lookup(id)->instantiate(mgr);
} }
void ast_uid::translate(global_scope& scope) {
}
void ast_uid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_uid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr(new instruction_pushglobal(id))); into.push_back(instruction_ptr(
new instruction_pushglobal(this->env->get_mangled_name(id))));
} }
void ast_binop::print(int indent, std::ostream& to) const { void ast_binop::print(int indent, std::ostream& to) const {
@@ -70,15 +87,15 @@ void ast_binop::print(int indent, std::ostream& to) const {
right->print(indent + 1, to); right->print(indent + 1, to);
} }
void ast_binop::find_free(type_mgr& mgr, type_env_ptr& env, std::set<std::string>& into) { void ast_binop::find_free(std::set<std::string>& into) {
this->env = env; left->find_free(into);
left->find_free(mgr, env, into); right->find_free(into);
right->find_free(mgr, env, into);
} }
type_ptr ast_binop::typecheck(type_mgr& mgr) { type_ptr ast_binop::typecheck(type_mgr& mgr, type_env_ptr& env) {
type_ptr ltype = left->typecheck(mgr); this->env = env;
type_ptr rtype = right->typecheck(mgr); type_ptr ltype = left->typecheck(mgr, env);
type_ptr rtype = right->typecheck(mgr, env);
type_ptr ftype = env->lookup(op_name(op))->instantiate(mgr); type_ptr ftype = env->lookup(op_name(op))->instantiate(mgr);
if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op)); if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op));
@@ -90,6 +107,11 @@ type_ptr ast_binop::typecheck(type_mgr& mgr) {
return return_type; return return_type;
} }
void ast_binop::translate(global_scope& scope) {
left->translate(scope);
right->translate(scope);
}
void ast_binop::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_binop::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
right->compile(env, into); right->compile(env, into);
left->compile(env_ptr(new env_offset(1, env)), into); left->compile(env_ptr(new env_offset(1, env)), into);
@@ -106,15 +128,15 @@ void ast_app::print(int indent, std::ostream& to) const {
right->print(indent + 1, to); right->print(indent + 1, to);
} }
void ast_app::find_free(type_mgr& mgr, type_env_ptr& env, std::set<std::string>& into) { void ast_app::find_free(std::set<std::string>& into) {
this->env = env; left->find_free(into);
left->find_free(mgr, env, into); right->find_free(into);
right->find_free(mgr, env, into);
} }
type_ptr ast_app::typecheck(type_mgr& mgr) { type_ptr ast_app::typecheck(type_mgr& mgr, type_env_ptr& env) {
type_ptr ltype = left->typecheck(mgr); this->env = env;
type_ptr rtype = right->typecheck(mgr); type_ptr ltype = left->typecheck(mgr, env);
type_ptr rtype = right->typecheck(mgr, env);
type_ptr return_type = mgr.new_type(); type_ptr return_type = mgr.new_type();
type_ptr arrow = type_ptr(new type_arr(rtype, return_type)); type_ptr arrow = type_ptr(new type_arr(rtype, return_type));
@@ -122,6 +144,11 @@ type_ptr ast_app::typecheck(type_mgr& mgr) {
return return_type; return return_type;
} }
void ast_app::translate(global_scope& scope) {
left->translate(scope);
right->translate(scope);
}
void ast_app::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_app::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
right->compile(env, into); right->compile(env, into);
left->compile(env_ptr(new env_offset(1, env)), into); left->compile(env_ptr(new env_offset(1, env)), into);
@@ -139,24 +166,30 @@ void ast_case::print(int indent, std::ostream& to) const {
} }
} }
void ast_case::find_free(type_mgr& mgr, type_env_ptr& env, std::set<std::string>& into) { void ast_case::find_free(std::set<std::string>& into) {
this->env = env; of->find_free(into);
of->find_free(mgr, env, into);
for(auto& branch : branches) { for(auto& branch : branches) {
type_env_ptr new_env = type_scope(env); std::set<std::string> free_in_branch;
branch->pat->insert_bindings(mgr, new_env); std::set<std::string> pattern_variables;
branch->expr->find_free(mgr, new_env, into); branch->pat->find_variables(pattern_variables);
branch->expr->find_free(free_in_branch);
for(auto& free : free_in_branch) {
if(pattern_variables.find(free) == pattern_variables.end())
into.insert(free);
}
} }
} }
type_ptr ast_case::typecheck(type_mgr& mgr) { type_ptr ast_case::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
type_var* var; type_var* var;
type_ptr case_type = mgr.resolve(of->typecheck(mgr), var); type_ptr case_type = mgr.resolve(of->typecheck(mgr, env), var);
type_ptr branch_type = mgr.new_type(); type_ptr branch_type = mgr.new_type();
for(auto& branch : branches) { for(auto& branch : branches) {
branch->pat->typecheck(case_type, mgr, branch->expr->env); type_env_ptr new_env = type_scope(env);
type_ptr curr_branch_type = branch->expr->typecheck(mgr); branch->pat->typecheck(case_type, mgr, new_env);
type_ptr curr_branch_type = branch->expr->typecheck(mgr, new_env);
mgr.unify(branch_type, curr_branch_type); mgr.unify(branch_type, curr_branch_type);
} }
@@ -170,6 +203,13 @@ type_ptr ast_case::typecheck(type_mgr& mgr) {
return branch_type; return branch_type;
} }
void ast_case::translate(global_scope& scope) {
of->translate(scope);
for(auto& branch : branches) {
branch->expr->translate(scope);
}
}
void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
type_app* app_type = dynamic_cast<type_app*>(input_type.get()); type_app* app_type = dynamic_cast<type_app*>(input_type.get());
type_data* type = dynamic_cast<type_data*>(app_type->constructor.get()); type_data* type = dynamic_cast<type_data*>(app_type->constructor.get());
@@ -199,7 +239,7 @@ void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) c
} else if((cpat = dynamic_cast<pattern_constr*>(branch->pat.get()))) { } else if((cpat = dynamic_cast<pattern_constr*>(branch->pat.get()))) {
env_ptr new_env = env; env_ptr new_env = env;
for(auto it = cpat->params.rbegin(); it != cpat->params.rend(); it++) { for(auto it = cpat->params.rbegin(); it != cpat->params.rend(); it++) {
new_env = env_ptr(new env_var(*it, new_env)); new_env = env_ptr(new env_var(branch->expr->env->get_mangled_name(*it), new_env));
} }
branch_instructions.push_back(instruction_ptr(new instruction_split( branch_instructions.push_back(instruction_ptr(new instruction_split(
@@ -226,16 +266,145 @@ void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) c
} }
} }
void ast_let::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "LET: " << std::endl;
in->print(indent + 1, to);
}
void ast_let::find_free(std::set<std::string>& into) {
definitions.find_free(into);
std::set<std::string> all_free;
in->find_free(all_free);
for(auto& free_var : all_free) {
if(definitions.defs_defn.find(free_var) == definitions.defs_defn.end())
into.insert(free_var);
}
}
type_ptr ast_let::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
definitions.typecheck(mgr, env);
return in->typecheck(mgr, definitions.env);
}
void ast_let::translate(global_scope& scope) {
for(auto& def : definitions.defs_data) {
def.second->into_globals(scope);
}
for(auto& def : definitions.defs_defn) {
size_t original_params = def.second->params.size();
std::string original_name = def.second->name;
auto& global_definition = def.second->into_global(scope);
size_t captured = global_definition.params.size() - original_params;
type_env_ptr mangled_env = type_scope(env);
mangled_env->bind(def.first, env->lookup(def.first), visibility::global);
mangled_env->set_mangled_name(def.first, global_definition.name);
ast_ptr global_app(new ast_lid(original_name));
global_app->env = mangled_env;
for(auto& param : global_definition.params) {
if(!(captured--)) break;
ast_ptr new_arg(new ast_lid(param));
new_arg->env = env;
global_app = ast_ptr(new ast_app(std::move(global_app), std::move(new_arg)));
global_app->env = env;
}
translated_definitions.push_back({ def.first, std::move(global_app) });
}
in->translate(scope);
}
void ast_let::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr(new instruction_alloc(translated_definitions.size())));
env_ptr new_env = env;
for(auto& def : translated_definitions) {
new_env = env_ptr(new env_var(definitions.env->get_mangled_name(def.first), std::move(new_env)));
}
int offset = translated_definitions.size() - 1;
for(auto& def : translated_definitions) {
def.second->compile(new_env, into);
into.push_back(instruction_ptr(new instruction_update(offset--)));
}
in->compile(new_env, into);
into.push_back(instruction_ptr(new instruction_slide(translated_definitions.size())));
}
void ast_lambda::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "LAMBDA";
for(auto& param : params) {
to << " " << param;
}
to << std::endl;
body->print(indent+1, to);
}
void ast_lambda::find_free(std::set<std::string>& into) {
body->find_free(free_variables);
for(auto& param : params) {
free_variables.erase(param);
}
into.insert(free_variables.begin(), free_variables.end());
}
type_ptr ast_lambda::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
var_env = type_scope(env);
type_ptr return_type = mgr.new_type();
type_ptr full_type = return_type;
for(auto it = params.rbegin(); it != params.rend(); it++) {
type_ptr param_type = mgr.new_type();
var_env->bind(*it, param_type);
full_type = type_ptr(new type_arr(std::move(param_type), full_type));
}
mgr.unify(return_type, body->typecheck(mgr, var_env));
return full_type;
}
void ast_lambda::translate(global_scope& scope) {
std::vector<std::string> function_params;
for(auto& free_variable : free_variables) {
if(env->is_global(free_variable)) continue;
function_params.push_back(free_variable);
}
size_t captured_count = function_params.size();
function_params.insert(function_params.end(), params.begin(), params.end());
auto& new_function = scope.add_function("lambda", std::move(function_params), std::move(body));
type_env_ptr mangled_env = type_scope(env);
mangled_env->bind("lambda", type_scheme_ptr(nullptr), visibility::global);
mangled_env->set_mangled_name("lambda", new_function.name);
ast_ptr new_application = ast_ptr(new ast_lid("lambda"));
new_application->env = mangled_env;
for(auto& param : new_function.params) {
if(!(captured_count--)) break;
ast_ptr new_arg = ast_ptr(new ast_lid(param));
new_arg->env = env;
new_application = ast_ptr(new ast_app(std::move(new_application), std::move(new_arg)));
new_application->env = env;
}
translated = std::move(new_application);
}
void ast_lambda::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
translated->compile(env, into);
}
void pattern_var::print(std::ostream& to) const { void pattern_var::print(std::ostream& to) const {
to << var; to << var;
} }
void pattern_var::insert_bindings(type_mgr& mgr, type_env_ptr& env) const { void pattern_var::find_variables(std::set<std::string>& into) const {
env->bind(var, mgr.new_type()); into.insert(var);
} }
void pattern_var::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const { void pattern_var::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const {
mgr.unify(env->lookup(var)->instantiate(mgr), t); env->bind(var, t);
} }
void pattern_constr::print(std::ostream& to) const { void pattern_constr::print(std::ostream& to) const {
@@ -245,23 +414,22 @@ void pattern_constr::print(std::ostream& to) const {
} }
} }
void pattern_constr::insert_bindings(type_mgr& mgr, type_env_ptr& env) const { void pattern_constr::find_variables(std::set<std::string>& into) const {
for(auto& param : params) { into.insert(params.begin(), params.end());
env->bind(param, mgr.new_type());
}
} }
void pattern_constr::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const { void pattern_constr::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const {
type_ptr constructor_type = env->lookup(constr)->instantiate(mgr); type_scheme_ptr constructor_type_scheme = env->lookup(constr);
if(!constructor_type) { if(!constructor_type_scheme) {
throw type_error(std::string("pattern using unknown constructor ") + constr); throw type_error(std::string("pattern using unknown constructor ") + constr);
} }
type_ptr constructor_type = constructor_type_scheme->instantiate(mgr);
for(auto& param : params) { for(auto& param : params) {
type_arr* arr = dynamic_cast<type_arr*>(constructor_type.get()); type_arr* arr = dynamic_cast<type_arr*>(constructor_type.get());
if(!arr) throw type_error("too many parameters in constructor pattern"); if(!arr) throw type_error("too many parameters in constructor pattern");
mgr.unify(env->lookup(param)->instantiate(mgr), arr->left); env->bind(param, arr->left);
constructor_type = arr->right; constructor_type = arr->right;
} }

View File

@@ -7,6 +7,8 @@
#include "binop.hpp" #include "binop.hpp"
#include "instruction.hpp" #include "instruction.hpp"
#include "env.hpp" #include "env.hpp"
#include "definition.hpp"
#include "global_scope.hpp"
struct ast { struct ast {
type_env_ptr env; type_env_ptr env;
@@ -14,9 +16,9 @@ struct ast {
virtual ~ast() = default; virtual ~ast() = default;
virtual void print(int indent, std::ostream& to) const = 0; virtual void print(int indent, std::ostream& to) const = 0;
virtual void find_free(type_mgr& mgr, virtual void find_free(std::set<std::string>& into) = 0;
type_env_ptr& env, std::set<std::string>& into) = 0; virtual type_ptr typecheck(type_mgr& mgr, type_env_ptr& env) = 0;
virtual type_ptr typecheck(type_mgr& mgr) = 0; virtual void translate(global_scope& scope) = 0;
virtual void compile(const env_ptr& env, virtual void compile(const env_ptr& env,
std::vector<instruction_ptr>& into) const = 0; std::vector<instruction_ptr>& into) const = 0;
}; };
@@ -27,7 +29,7 @@ struct pattern {
virtual ~pattern() = default; virtual ~pattern() = default;
virtual void print(std::ostream& to) const = 0; virtual void print(std::ostream& to) const = 0;
virtual void insert_bindings(type_mgr& mgr, type_env_ptr& env) const = 0; virtual void find_variables(std::set<std::string>& into) const = 0;
virtual void typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const = 0; virtual void typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const = 0;
}; };
@@ -50,8 +52,9 @@ struct ast_int : public ast {
: value(v) {} : value(v) {}
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
void find_free(type_mgr& mgr, type_env_ptr& env, std::set<std::string>& into); void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr); type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@@ -62,8 +65,9 @@ struct ast_lid : public ast {
: id(std::move(i)) {} : id(std::move(i)) {}
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
void find_free(type_mgr& mgr, type_env_ptr& env, std::set<std::string>& into); void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr); type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@@ -74,8 +78,9 @@ struct ast_uid : public ast {
: id(std::move(i)) {} : id(std::move(i)) {}
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
void find_free(type_mgr& mgr, type_env_ptr& env, std::set<std::string>& into); void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr); type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@@ -88,8 +93,9 @@ struct ast_binop : public ast {
: op(o), left(std::move(l)), right(std::move(r)) {} : op(o), left(std::move(l)), right(std::move(r)) {}
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
void find_free(type_mgr& mgr, type_env_ptr& env, std::set<std::string>& into); void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr); type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@@ -101,8 +107,9 @@ struct ast_app : public ast {
: left(std::move(l)), right(std::move(r)) {} : left(std::move(l)), right(std::move(r)) {}
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
void find_free(type_mgr& mgr, type_env_ptr& env, std::set<std::string>& into); void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr); type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@@ -115,8 +122,46 @@ struct ast_case : public ast {
: of(std::move(o)), branches(std::move(b)) {} : of(std::move(o)), branches(std::move(b)) {}
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
void find_free(type_mgr& mgr, type_env_ptr& env, std::set<std::string>& into); void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr); type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
};
struct ast_let : public ast {
using basic_definition = std::pair<std::string, ast_ptr>;
definition_group definitions;
ast_ptr in;
std::vector<basic_definition> translated_definitions;
ast_let(definition_group g, ast_ptr i)
: definitions(std::move(g)), in(std::move(i)) {}
void print(int indent, std::ostream& to) const;
void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
};
struct ast_lambda : public ast {
std::vector<std::string> params;
ast_ptr body;
type_env_ptr var_env;
std::set<std::string> free_variables;
ast_ptr translated;
ast_lambda(std::vector<std::string> ps, ast_ptr b)
: params(std::move(ps)), body(std::move(b)) {}
void print(int indent, std::ostream& to) const;
void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@@ -127,7 +172,7 @@ struct pattern_var : public pattern {
: var(std::move(v)) {} : var(std::move(v)) {}
void print(std::ostream &to) const; void print(std::ostream &to) const;
void insert_bindings(type_mgr& mgr, type_env_ptr& env) const; void find_variables(std::set<std::string>& into) const;
void typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const; void typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const;
}; };
@@ -139,6 +184,6 @@ struct pattern_constr : public pattern {
: constr(std::move(c)), params(std::move(p)) {} : constr(std::move(c)), params(std::move(p)) {}
void print(std::ostream &to) const; void print(std::ostream &to) const;
virtual void insert_bindings(type_mgr& mgr, type_env_ptr& env) const; void find_variables(std::set<std::string>& into) const;
virtual void typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const; virtual void typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const;
}; };

View File

@@ -5,13 +5,20 @@
#include "llvm_context.hpp" #include "llvm_context.hpp"
#include "type.hpp" #include "type.hpp"
#include "type_env.hpp" #include "type_env.hpp"
#include "graph.hpp"
#include <llvm/IR/DerivedTypes.h> #include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h> #include <llvm/IR/Function.h>
#include <llvm/IR/Type.h> #include <llvm/IR/Type.h>
void definition_defn::find_free(type_mgr& mgr, type_env_ptr& env) { void definition_defn::find_free() {
this->env = env; body->find_free(free_variables);
for(auto& param : params) {
free_variables.erase(param);
}
}
void definition_defn::insert_types(type_mgr& mgr, type_env_ptr& env, visibility v) {
this->env = env;
var_env = type_scope(env); var_env = type_scope(env);
return_type = mgr.new_type(); return_type = mgr.new_type();
full_type = return_type; full_type = return_type;
@@ -21,39 +28,24 @@ void definition_defn::find_free(type_mgr& mgr, type_env_ptr& env) {
full_type = type_ptr(new type_arr(param_type, full_type)); full_type = type_ptr(new type_arr(param_type, full_type));
var_env->bind(*it, param_type); var_env->bind(*it, param_type);
} }
env->bind(name, full_type, v);
body->find_free(mgr, var_env, free_variables);
}
void definition_defn::insert_types(type_mgr& mgr) {
env->bind(name, full_type);
} }
void definition_defn::typecheck(type_mgr& mgr) { void definition_defn::typecheck(type_mgr& mgr) {
type_ptr body_type = body->typecheck(mgr); type_ptr body_type = body->typecheck(mgr, var_env);
mgr.unify(return_type, body_type); mgr.unify(return_type, body_type);
} }
void definition_defn::compile() {
env_ptr new_env = env_ptr(new env_offset(0, nullptr));
for(auto it = params.rbegin(); it != params.rend(); it++) {
new_env = env_ptr(new env_var(*it, new_env));
}
body->compile(new_env, instructions);
instructions.push_back(instruction_ptr(new instruction_update(params.size())));
instructions.push_back(instruction_ptr(new instruction_pop(params.size())));
}
void definition_defn::declare_llvm(llvm_context& ctx) { global_function& definition_defn::into_global(global_scope& scope) {
generated_function = ctx.create_custom_function(name, params.size()); std::vector<std::string> all_params;
for(auto& free : free_variables) {
if(env->is_global(free)) continue;
all_params.push_back(free);
} }
all_params.insert(all_params.end(), params.begin(), params.end());
void definition_defn::generate_llvm(llvm_context& ctx) { body->translate(scope);
ctx.builder.SetInsertPoint(&generated_function->getEntryBlock()); return scope.add_function(name, std::move(all_params), std::move(body));
for(auto& instruction : instructions) {
instruction->gen_llvm(ctx, generated_function);
}
ctx.builder.CreateRetVoid();
} }
void definition_data::insert_types(type_env_ptr& env) { void definition_data::insert_types(type_env_ptr& env) {
@@ -91,19 +83,63 @@ void definition_data::insert_constructors() const {
} }
} }
void definition_data::generate_llvm(llvm_context& ctx) { void definition_data::into_globals(global_scope& scope) {
for(auto& constructor : constructors) { for(auto& constructor : constructors) {
auto new_function = global_constructor& c = scope.add_constructor(
ctx.create_custom_function(constructor->name, constructor->types.size()); constructor->name, constructor->tag, constructor->types.size());
std::vector<instruction_ptr> instructions; env->set_mangled_name(constructor->name, c.name);
instructions.push_back(instruction_ptr( }
new instruction_pack(constructor->tag, constructor->types.size()) }
));
instructions.push_back(instruction_ptr(new instruction_update(0))); void definition_group::find_free(std::set<std::string>& into) {
ctx.builder.SetInsertPoint(&new_function->getEntryBlock()); for(auto& def_pair : defs_defn) {
for (auto& instruction : instructions) { def_pair.second->find_free();
instruction->gen_llvm(ctx, new_function); for(auto& free_var : def_pair.second->free_variables) {
} if(defs_defn.find(free_var) == defs_defn.end()) {
ctx.builder.CreateRetVoid(); into.insert(free_var);
} else {
def_pair.second->nearby_variables.insert(free_var);
}
}
}
}
void definition_group::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = type_scope(env);
for(auto& def_data : defs_data) {
def_data.second->insert_types(this->env);
}
for(auto& def_data : defs_data) {
def_data.second->insert_constructors();
}
function_graph dependency_graph;
for(auto& def_defn : defs_defn) {
def_defn.second->find_free();
dependency_graph.add_function(def_defn.second->name);
for(auto& dependency : def_defn.second->nearby_variables) {
if(defs_defn.find(dependency) == defs_defn.end())
throw 0;
dependency_graph.add_edge(def_defn.second->name, dependency);
}
}
std::vector<group_ptr> groups = dependency_graph.compute_order();
for(auto it = groups.rbegin(); it != groups.rend(); it++) {
auto& group = *it;
for(auto& def_defnn_name : group->members) {
auto& def_defn = defs_defn.find(def_defnn_name)->second;
def_defn->insert_types(mgr, this->env, vis);
}
for(auto& def_defnn_name : group->members) {
auto& def_defn = defs_defn.find(def_defnn_name)->second;
def_defn->typecheck(mgr);
}
for(auto& def_defnn_name : group->members) {
this->env->generalize(def_defnn_name, *group, mgr);
}
} }
} }

View File

@@ -1,11 +1,13 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <map>
#include <set> #include <set>
#include "instruction.hpp" #include "instruction.hpp"
#include "llvm_context.hpp" #include "llvm_context.hpp"
#include "parsed_type.hpp" #include "parsed_type.hpp"
#include "type_env.hpp" #include "type_env.hpp"
#include "global_scope.hpp"
struct ast; struct ast;
using ast_ptr = std::unique_ptr<ast>; using ast_ptr = std::unique_ptr<ast>;
@@ -29,24 +31,20 @@ struct definition_defn {
type_env_ptr env; type_env_ptr env;
type_env_ptr var_env; type_env_ptr var_env;
std::set<std::string> free_variables; std::set<std::string> free_variables;
std::set<std::string> nearby_variables;
type_ptr full_type; type_ptr full_type;
type_ptr return_type; type_ptr return_type;
std::vector<instruction_ptr> instructions;
llvm::Function* generated_function;
definition_defn(std::string n, std::vector<std::string> p, ast_ptr b) definition_defn(std::string n, std::vector<std::string> p, ast_ptr b)
: name(std::move(n)), params(std::move(p)), body(std::move(b)) { : name(std::move(n)), params(std::move(p)), body(std::move(b)) {
} }
void find_free(type_mgr& mgr, type_env_ptr& env); void find_free();
void insert_types(type_mgr& mgr); void insert_types(type_mgr& mgr, type_env_ptr& env, visibility v);
void typecheck(type_mgr& mgr); void typecheck(type_mgr& mgr);
void compile();
void declare_llvm(llvm_context& ctx); global_function& into_global(global_scope& scope);
void generate_llvm(llvm_context& ctx);
}; };
using definition_defn_ptr = std::unique_ptr<definition_defn>; using definition_defn_ptr = std::unique_ptr<definition_defn>;
@@ -66,7 +64,20 @@ struct definition_data {
void insert_types(type_env_ptr& env); void insert_types(type_env_ptr& env);
void insert_constructors() const; void insert_constructors() const;
void generate_llvm(llvm_context& ctx);
void into_globals(global_scope& scope);
}; };
using definition_data_ptr = std::unique_ptr<definition_data>; using definition_data_ptr = std::unique_ptr<definition_data>;
struct definition_group {
std::map<std::string, definition_data_ptr> defs_data;
std::map<std::string, definition_defn_ptr> defs_defn;
visibility vis;
type_env_ptr env;
definition_group(visibility v = visibility::local) : vis(v) {}
void find_free(std::set<std::string>& into);
void typecheck(type_mgr& mgr, type_env_ptr& env);
};

View File

@@ -15,7 +15,7 @@ struct env_var : public env {
std::string name; std::string name;
env_ptr parent; env_ptr parent;
env_var(std::string& n, env_ptr p) env_var(std::string n, env_ptr p)
: name(std::move(n)), parent(std::move(p)) {} : name(std::move(n)), parent(std::move(p)) {}
int get_offset(const std::string& name) const; int get_offset(const std::string& name) const;

View File

@@ -0,0 +1,17 @@
data List a = { Nil, Cons a (List a) }
defn fix f = { let { defn x = { f x } } in { x } }
defn fixpointOnes fo = { Cons 1 fo }
defn sumTwo l = {
case l of {
Nil -> { 0 }
Cons x xs -> {
x + case xs of {
Nil -> { 0 }
Cons y ys -> { y }
}
}
}
}
defn main = { sumTwo (fix fixpointOnes) }

View File

@@ -0,0 +1,19 @@
data List a = { Nil, Cons a (List a) }
defn sum l = {
case l of {
Nil -> { 0 }
Cons x xs -> { x + sum xs}
}
}
defn map f l = {
case l of {
Nil -> { Nil }
Cons x xs -> { Cons (f x) (map f xs) }
}
}
defn main = {
sum (map \x -> { x * x } (map (\x -> { x + x }) (Cons 1 (Cons 2 (Cons 3 Nil)))))
}

View File

@@ -0,0 +1,47 @@
data Bool = { True, False }
data List a = { Nil, Cons a (List a) }
defn if c t e = {
case c of {
True -> { t }
False -> { e }
}
}
defn mergeUntil l r p = {
let {
defn mergeLeft nl nr = {
case nl of {
Nil -> { Nil }
Cons x xs -> { if (p x) (Cons x (mergeRight xs nr)) Nil }
}
}
defn mergeRight nl nr = {
case nr of {
Nil -> { Nil }
Cons x xs -> { if (p x) (Cons x (mergeLeft nl xs)) Nil }
}
}
} in {
mergeLeft l r
}
}
defn const x y = { x }
defn sum l = {
case l of {
Nil -> { 0 }
Cons x xs -> { x + sum xs }
}
}
defn main = {
let {
defn firstList = { Cons 1 (Cons 3 (Cons 5 Nil)) }
defn secondList = { Cons 2 (Cons 4 (Cons 6 Nil)) }
} in {
sum (mergeUntil firstList secondList (const True))
}
}

View File

@@ -0,0 +1,23 @@
data Pair a b = { Pair a b }
defn packer = {
let {
data Packed a = { Packed a }
defn pack a = { Packed a }
defn unpack p = {
case p of {
Packed a -> { a }
}
}
} in {
Pair pack unpack
}
}
defn main = {
case packer of {
Pair pack unpack -> {
unpack (pack 3)
}
}
}

View File

@@ -0,0 +1,83 @@
#include "global_scope.hpp"
#include "ast.hpp"
void global_function::compile() {
env_ptr new_env = env_ptr(new env_offset(0, nullptr));
for(auto it = params.rbegin(); it != params.rend(); it++) {
new_env = env_ptr(new env_var(*it, new_env));
}
body->compile(new_env, instructions);
instructions.push_back(instruction_ptr(new instruction_update(params.size())));
instructions.push_back(instruction_ptr(new instruction_pop(params.size())));
}
void global_function::declare_llvm(llvm_context& ctx) {
generated_function = ctx.create_custom_function(name, params.size());
}
void global_function::generate_llvm(llvm_context& ctx) {
ctx.builder.SetInsertPoint(&generated_function->getEntryBlock());
for(auto& instruction : instructions) {
instruction->gen_llvm(ctx, generated_function);
}
ctx.builder.CreateRetVoid();
}
void global_constructor::generate_llvm(llvm_context& ctx) {
auto new_function =
ctx.create_custom_function(name, arity);
std::vector<instruction_ptr> instructions;
instructions.push_back(instruction_ptr(new instruction_pack(tag, arity)));
instructions.push_back(instruction_ptr(new instruction_update(0)));
ctx.builder.SetInsertPoint(&new_function->getEntryBlock());
for (auto& instruction : instructions) {
instruction->gen_llvm(ctx, new_function);
}
ctx.builder.CreateRetVoid();
}
global_function& global_scope::add_function(std::string n, std::vector<std::string> ps, ast_ptr b) {
global_function* new_function = new global_function(mangle_name(n), std::move(ps), std::move(b));
functions.push_back(global_function_ptr(new_function));
return *new_function;
}
global_constructor& global_scope::add_constructor(std::string n, int8_t t, size_t a) {
global_constructor* new_constructor = new global_constructor(mangle_name(n), t, a);
constructors.push_back(global_constructor_ptr(new_constructor));
return *new_constructor;
}
void global_scope::compile() {
for(auto& function : functions) {
function->compile();
}
}
void global_scope::generate_llvm(llvm_context& ctx) {
for(auto& constructor : constructors) {
constructor->generate_llvm(ctx);
}
for(auto& function : functions) {
function->declare_llvm(ctx);
}
for(auto& function : functions) {
function->generate_llvm(ctx);
}
}
std::string global_scope::mangle_name(const std::string& n) {
auto occurence_it = occurence_count.find(n);
int occurence = 0;
if(occurence_it != occurence_count.end()) {
occurence = occurence_it->second + 1;
}
occurence_count[n] = occurence;
std::string final_name = n;
if (occurence != 0) {
final_name += "_";
final_name += std::to_string(occurence);
}
return final_name;
}

View File

@@ -0,0 +1,55 @@
#pragma once
#include <memory>
#include <string>
#include <vector>
#include <llvm/IR/Function.h>
#include "instruction.hpp"
struct ast;
using ast_ptr = std::unique_ptr<ast>;
struct global_function {
std::string name;
std::vector<std::string> params;
ast_ptr body;
std::vector<instruction_ptr> instructions;
llvm::Function* generated_function;
global_function(std::string n, std::vector<std::string> ps, ast_ptr b)
: name(std::move(n)), params(std::move(ps)), body(std::move(b)) {}
void compile();
void declare_llvm(llvm_context& ctx);
void generate_llvm(llvm_context& ctx);
};
using global_function_ptr = std::unique_ptr<global_function>;
struct global_constructor {
std::string name;
int8_t tag;
size_t arity;
global_constructor(std::string n, int8_t t, size_t a)
: name(std::move(n)), tag(t), arity(a) {}
void generate_llvm(llvm_context& ctx);
};
using global_constructor_ptr = std::unique_ptr<global_constructor>;
struct global_scope {
std::map<std::string, int> occurence_count;
std::vector<global_function_ptr> functions;
std::vector<global_constructor_ptr> constructors;
global_function& add_function(std::string n, std::vector<std::string> ps, ast_ptr b);
global_constructor& add_constructor(std::string n, int8_t t, size_t a);
void compile();
void generate_llvm(llvm_context& ctx);
private:
std::string mangle_name(const std::string& n);
};

View File

@@ -7,7 +7,6 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <iostream>
using function = std::string; using function = std::string;

View File

@@ -21,12 +21,10 @@ void yy::parser::error(const std::string& msg) {
std::cout << "An error occured: " << msg << std::endl; std::cout << "An error occured: " << msg << std::endl;
} }
extern std::map<std::string, definition_data_ptr> defs_data; extern definition_group global_defs;
extern std::map<std::string, definition_defn_ptr> defs_defn;
void typecheck_program( void typecheck_program(
const std::map<std::string, definition_data_ptr>& defs_data, definition_group& defs,
const std::map<std::string, definition_defn_ptr>& defs_defn,
type_mgr& mgr, type_env_ptr& env) { type_mgr& mgr, type_env_ptr& env) {
type_ptr int_type = type_ptr(new type_base("Int")); type_ptr int_type = type_ptr(new type_base("Int"));
env->bind_type("Int", int_type); env->bind_type("Int", int_type);
@@ -35,63 +33,32 @@ void typecheck_program(
type_ptr binop_type = type_ptr(new type_arr( type_ptr binop_type = type_ptr(new type_arr(
int_type_app, int_type_app,
type_ptr(new type_arr(int_type_app, int_type_app)))); type_ptr(new type_arr(int_type_app, int_type_app))));
env->bind("+", binop_type); env->bind("+", binop_type, visibility::global);
env->bind("-", binop_type); env->bind("-", binop_type, visibility::global);
env->bind("*", binop_type); env->bind("*", binop_type, visibility::global);
env->bind("/", binop_type); env->bind("/", binop_type, visibility::global);
for(auto& def_data : defs_data) { std::set<std::string> free;
def_data.second->insert_types(env); defs.find_free(free);
} defs.typecheck(mgr, env);
for(auto& def_data : defs_data) {
def_data.second->insert_constructors();
}
function_graph dependency_graph; for(auto& pair : defs.env->names) {
for(auto& def_defn : defs_defn) {
def_defn.second->find_free(mgr, env);
dependency_graph.add_function(def_defn.second->name);
for(auto& dependency : def_defn.second->free_variables) {
if(defs_defn.find(dependency) == defs_defn.end())
throw 0;
dependency_graph.add_edge(def_defn.second->name, dependency);
}
}
std::vector<group_ptr> groups = dependency_graph.compute_order();
for(auto it = groups.rbegin(); it != groups.rend(); it++) {
auto& group = *it;
for(auto& def_defnn_name : group->members) {
auto& def_defn = defs_defn.find(def_defnn_name)->second;
def_defn->insert_types(mgr);
}
for(auto& def_defnn_name : group->members) {
auto& def_defn = defs_defn.find(def_defnn_name)->second;
def_defn->typecheck(mgr);
}
for(auto& def_defnn_name : group->members) {
env->generalize(def_defnn_name, mgr);
}
}
for(auto& pair : env->names) {
std::cout << pair.first << ": "; std::cout << pair.first << ": ";
pair.second->print(mgr, std::cout); pair.second.type->print(mgr, std::cout);
std::cout << std::endl; std::cout << std::endl;
} }
} }
void compile_program(const std::map<std::string, definition_defn_ptr>& defs_defn) { global_scope translate_program(definition_group& group) {
for(auto& def_defn : defs_defn) { global_scope scope;
def_defn.second->compile(); for(auto& data : group.defs_data) {
data.second->into_globals(scope);
for(auto& instruction : def_defn.second->instructions) {
instruction->print(0, std::cout);
} }
std::cout << std::endl; for(auto& defn : group.defs_defn) {
auto& function = defn.second->into_global(scope);
function.body->env->parent->set_mangled_name(defn.first, function.name);
} }
return scope;
} }
void gen_llvm_internal_op(llvm_context& ctx, binop op) { void gen_llvm_internal_op(llvm_context& ctx, binop op) {
@@ -151,24 +118,14 @@ void output_llvm(llvm_context& ctx, const std::string& filename) {
} }
} }
void gen_llvm( void gen_llvm(global_scope& scope) {
const std::map<std::string, definition_data_ptr>& defs_data,
const std::map<std::string, definition_defn_ptr>& defs_defn) {
llvm_context ctx; llvm_context ctx;
gen_llvm_internal_op(ctx, PLUS); gen_llvm_internal_op(ctx, PLUS);
gen_llvm_internal_op(ctx, MINUS); gen_llvm_internal_op(ctx, MINUS);
gen_llvm_internal_op(ctx, TIMES); gen_llvm_internal_op(ctx, TIMES);
gen_llvm_internal_op(ctx, DIVIDE); gen_llvm_internal_op(ctx, DIVIDE);
for(auto& def_data : defs_data) { scope.generate_llvm(ctx);
def_data.second->generate_llvm(ctx);
}
for(auto& def_defn : defs_defn) {
def_defn.second->declare_llvm(ctx);
}
for(auto& def_defn : defs_defn) {
def_defn.second->generate_llvm(ctx);
}
ctx.module.print(llvm::outs(), nullptr); ctx.module.print(llvm::outs(), nullptr);
output_llvm(ctx, "program.o"); output_llvm(ctx, "program.o");
@@ -180,7 +137,7 @@ int main() {
type_env_ptr env(new type_env); type_env_ptr env(new type_env);
parser.parse(); parser.parse();
for(auto& def_defn : defs_defn) { for(auto& def_defn : global_defs.defs_defn) {
std::cout << def_defn.second->name; std::cout << def_defn.second->name;
for(auto& param : def_defn.second->params) std::cout << " " << param; for(auto& param : def_defn.second->params) std::cout << " " << param;
std::cout << ":" << std::endl; std::cout << ":" << std::endl;
@@ -188,9 +145,10 @@ int main() {
def_defn.second->body->print(1, std::cout); def_defn.second->body->print(1, std::cout);
} }
try { try {
typecheck_program(defs_data, defs_defn, mgr, env); typecheck_program(global_defs, mgr, env);
compile_program(defs_defn); global_scope scope = translate_program(global_defs);
gen_llvm(defs_data, defs_defn); scope.compile();
gen_llvm(scope);
} catch(unification_error& err) { } catch(unification_error& err) {
std::cout << "failed to unify types: " << std::endl; std::cout << "failed to unify types: " << std::endl;
std::cout << " (1) \033[34m"; std::cout << " (1) \033[34m";

View File

@@ -7,13 +7,13 @@
#include "parser.hpp" #include "parser.hpp"
#include "parsed_type.hpp" #include "parsed_type.hpp"
std::map<std::string, definition_data_ptr> defs_data; definition_group global_defs;
std::map<std::string, definition_defn_ptr> defs_defn;
extern yy::parser::symbol_type yylex(); extern yy::parser::symbol_type yylex();
%} %}
%token BACKSLASH
%token PLUS %token PLUS
%token TIMES %token TIMES
%token MINUS %token MINUS
@@ -23,6 +23,8 @@ extern yy::parser::symbol_type yylex();
%token DATA %token DATA
%token CASE %token CASE
%token OF %token OF
%token LET
%token IN
%token OCURLY %token OCURLY
%token CCURLY %token CCURLY
%token OPAREN %token OPAREN
@@ -41,8 +43,9 @@ extern yy::parser::symbol_type yylex();
%type <std::vector<branch_ptr>> branches %type <std::vector<branch_ptr>> branches
%type <std::vector<constructor_ptr>> constructors %type <std::vector<constructor_ptr>> constructors
%type <std::vector<parsed_type_ptr>> typeList %type <std::vector<parsed_type_ptr>> typeList
%type <definition_group> definitions
%type <parsed_type_ptr> type nonArrowType typeListElement %type <parsed_type_ptr> type nonArrowType typeListElement
%type <ast_ptr> aAdd aMul case app appBase %type <ast_ptr> aAdd aMul case let lambda app appBase
%type <definition_data_ptr> data %type <definition_data_ptr> data
%type <definition_defn_ptr> defn %type <definition_defn_ptr> defn
%type <branch_ptr> branch %type <branch_ptr> branch
@@ -54,17 +57,13 @@ extern yy::parser::symbol_type yylex();
%% %%
program program
: definitions { } : definitions { global_defs = std::move($1); global_defs.vis = visibility::global; }
; ;
definitions definitions
: definitions definition { } : definitions defn { $$ = std::move($1); auto name = $2->name; $$.defs_defn[name] = std::move($2); }
| definition { } | definitions data { $$ = std::move($1); auto name = $2->name; $$.defs_data[name] = std::move($2); }
; | %empty { $$ = definition_group(); }
definition
: defn { auto name = $1->name; defs_defn[name] = std::move($1); }
| data { auto name = $1->name; defs_data[name] = std::move($1); }
; ;
defn defn
@@ -101,6 +100,18 @@ appBase
| UID { $$ = ast_ptr(new ast_uid(std::move($1))); } | UID { $$ = ast_ptr(new ast_uid(std::move($1))); }
| OPAREN aAdd CPAREN { $$ = std::move($2); } | OPAREN aAdd CPAREN { $$ = std::move($2); }
| case { $$ = std::move($1); } | case { $$ = std::move($1); }
| let { $$ = std::move($1); }
| lambda { $$ = std::move($1); }
;
let
: LET OCURLY definitions CCURLY IN OCURLY aAdd CCURLY
{ $$ = ast_ptr(new ast_let(std::move($3), std::move($7))); }
;
lambda
: BACKSLASH lowercaseParams ARROW OCURLY aAdd CCURLY
{ $$ = ast_ptr(new ast_lambda(std::move($2), std::move($5))); }
; ;
case case

View File

@@ -13,6 +13,7 @@
%% %%
[ \n]+ {} [ \n]+ {}
\\ { return yy::parser::make_BACKSLASH(); }
\+ { return yy::parser::make_PLUS(); } \+ { return yy::parser::make_PLUS(); }
\* { return yy::parser::make_TIMES(); } \* { return yy::parser::make_TIMES(); }
- { return yy::parser::make_MINUS(); } - { return yy::parser::make_MINUS(); }
@@ -22,6 +23,8 @@ defn { return yy::parser::make_DEFN(); }
data { return yy::parser::make_DATA(); } data { return yy::parser::make_DATA(); }
case { return yy::parser::make_CASE(); } case { return yy::parser::make_CASE(); }
of { return yy::parser::make_OF(); } of { return yy::parser::make_OF(); }
let { return yy::parser::make_LET(); }
in { return yy::parser::make_IN(); }
\{ { return yy::parser::make_OCURLY(); } \{ { return yy::parser::make_OCURLY(); }
\} { return yy::parser::make_CCURLY(); } \} { return yy::parser::make_CCURLY(); }
\( { return yy::parser::make_OPAREN(); } \( { return yy::parser::make_OPAREN(); }

View File

@@ -5,6 +5,8 @@
#include <vector> #include <vector>
#include "error.hpp" #include "error.hpp"
bool type::is_arrow(const type_mgr& mgr) const { return false; }
void type_scheme::print(const type_mgr& mgr, std::ostream& to) const { void type_scheme::print(const type_mgr& mgr, std::ostream& to) const {
if(forall.size() != 0) { if(forall.size() != 0) {
to << "forall "; to << "forall ";
@@ -34,15 +36,30 @@ void type_var::print(const type_mgr& mgr, std::ostream& to) const {
} }
} }
bool type_var::is_arrow(const type_mgr& mgr) const {
auto it = mgr.types.find(name);
if(it != mgr.types.end()) {
return it->second->is_arrow(mgr);
} else {
return false;
}
}
void type_base::print(const type_mgr& mgr, std::ostream& to) const { void type_base::print(const type_mgr& mgr, std::ostream& to) const {
to << name; to << name;
} }
void type_arr::print(const type_mgr& mgr, std::ostream& to) const { void type_arr::print(const type_mgr& mgr, std::ostream& to) const {
bool print_parenths = left->is_arrow(mgr);
if(print_parenths) to << "(";
left->print(mgr, to); left->print(mgr, to);
to << " -> ("; if(print_parenths) to << ")";
to << " -> ";
right->print(mgr, to); right->print(mgr, to);
to << ")"; }
bool type_arr::is_arrow(const type_mgr& mgr) const {
return true;
} }
void type_app::print(const type_mgr& mgr, std::ostream& to) const { void type_app::print(const type_mgr& mgr, std::ostream& to) const {
@@ -185,3 +202,18 @@ void type_mgr::find_free(const type_ptr& t, std::set<std::string>& into) const {
for(auto& arg : app->arguments) find_free(arg, into); for(auto& arg : app->arguments) find_free(arg, into);
} }
} }
void type_mgr::find_free(const type_scheme_ptr& t, std::set<std::string>& into) const {
std::set<std::string> monotype_free;
type_mgr limited_mgr;
for(auto& binding : types) {
auto existing_position = std::find(t->forall.begin(), t->forall.end(), binding.first);
if(existing_position != t->forall.end()) continue;
limited_mgr.types[binding.first] = binding.second;
}
limited_mgr.find_free(t->monotype, monotype_free);
for(auto& not_free : t->forall) {
monotype_free.erase(not_free);
}
into.insert(monotype_free.begin(), monotype_free.end());
}

View File

@@ -11,6 +11,7 @@ struct type {
virtual ~type() = default; virtual ~type() = default;
virtual void print(const type_mgr& mgr, std::ostream& to) const = 0; virtual void print(const type_mgr& mgr, std::ostream& to) const = 0;
virtual bool is_arrow(const type_mgr& mgr) const;
}; };
using type_ptr = std::shared_ptr<type>; using type_ptr = std::shared_ptr<type>;
@@ -34,6 +35,7 @@ struct type_var : public type {
: name(std::move(n)) {} : name(std::move(n)) {}
void print(const type_mgr& mgr, std::ostream& to) const; void print(const type_mgr& mgr, std::ostream& to) const;
bool is_arrow(const type_mgr& mgr) const;
}; };
struct type_base : public type { struct type_base : public type {
@@ -65,6 +67,7 @@ struct type_arr : public type {
: left(std::move(l)), right(std::move(r)) {} : left(std::move(l)), right(std::move(r)) {}
void print(const type_mgr& mgr, std::ostream& to) const; void print(const type_mgr& mgr, std::ostream& to) const;
bool is_arrow(const type_mgr& mgr) const;
}; };
struct type_app : public type { struct type_app : public type {
@@ -92,4 +95,5 @@ struct type_mgr {
type_ptr resolve(type_ptr t, type_var*& var) const; type_ptr resolve(type_ptr t, type_var*& var) const;
void bind(const std::string& s, type_ptr t); void bind(const std::string& s, type_ptr t);
void find_free(const type_ptr& t, std::set<std::string>& into) const; void find_free(const type_ptr& t, std::set<std::string>& into) const;
void find_free(const type_scheme_ptr& t, std::set<std::string>& into) const;
}; };

View File

@@ -1,13 +1,49 @@
#include "type_env.hpp" #include "type_env.hpp"
#include "type.hpp" #include "type.hpp"
void type_env::find_free(const type_mgr& mgr, std::set<std::string>& into) const {
if(parent != nullptr) parent->find_free(mgr, into);
for(auto& binding : names) {
mgr.find_free(binding.second.type, into);
}
}
void type_env::find_free_except(const type_mgr& mgr, const group& avoid,
std::set<std::string>& into) const {
if(parent != nullptr) parent->find_free(mgr, into);
for(auto& binding : names) {
if(avoid.members.find(binding.first) != avoid.members.end()) continue;
mgr.find_free(binding.second.type, into);
}
}
type_scheme_ptr type_env::lookup(const std::string& name) const { type_scheme_ptr type_env::lookup(const std::string& name) const {
auto it = names.find(name); auto it = names.find(name);
if(it != names.end()) return it->second; if(it != names.end()) return it->second.type;
if(parent) return parent->lookup(name); if(parent) return parent->lookup(name);
return nullptr; return nullptr;
} }
bool type_env::is_global(const std::string& name) const {
auto it = names.find(name);
if(it != names.end()) return it->second.vis == visibility::global;
if(parent) return parent->is_global(name);
return false;
}
void type_env::set_mangled_name(const std::string& name, const std::string& mangled) {
auto it = names.find(name);
if(it != names.end()) it->second.mangled_name = mangled;
}
const std::string& type_env::get_mangled_name(const std::string& name) const {
auto it = names.find(name);
if(it != names.end())
return (it->second.mangled_name != "") ? it->second.mangled_name : name;
if(parent) return parent->get_mangled_name(name);
return name;
}
type_ptr type_env::lookup_type(const std::string& name) const { type_ptr type_env::lookup_type(const std::string& name) const {
auto it = type_names.find(name); auto it = type_names.find(name);
if(it != type_names.end()) return it->second; if(it != type_names.end()) return it->second;
@@ -15,12 +51,13 @@ type_ptr type_env::lookup_type(const std::string& name) const {
return nullptr; return nullptr;
} }
void type_env::bind(const std::string& name, type_ptr t) { void type_env::bind(const std::string& name, type_ptr t, visibility v) {
names[name] = type_scheme_ptr(new type_scheme(t)); type_scheme_ptr new_scheme(new type_scheme(std::move(t)));
names[name] = variable_data(std::move(new_scheme), v, "");
} }
void type_env::bind(const std::string& name, type_scheme_ptr t) { void type_env::bind(const std::string& name, type_scheme_ptr t, visibility v) {
names[name] = t; names[name] = variable_data(std::move(t), v, "");
} }
void type_env::bind_type(const std::string& type_name, type_ptr t) { void type_env::bind_type(const std::string& type_name, type_ptr t) {
@@ -28,15 +65,18 @@ void type_env::bind_type(const std::string& type_name, type_ptr t) {
type_names[type_name] = t; type_names[type_name] = t;
} }
void type_env::generalize(const std::string& name, type_mgr& mgr) { void type_env::generalize(const std::string& name, const group& grp, type_mgr& mgr) {
auto names_it = names.find(name); auto names_it = names.find(name);
if(names_it == names.end()) throw 0; if(names_it == names.end()) throw 0;
if(names_it->second->forall.size() > 0) throw 0; if(names_it->second.type->forall.size() > 0) throw 0;
std::set<std::string> free_variables; std::set<std::string> free_in_type;
mgr.find_free(names_it->second->monotype, free_variables); std::set<std::string> free_in_env;
for(auto& free : free_variables) { mgr.find_free(names_it->second.type->monotype, free_in_type);
names_it->second->forall.push_back(free); find_free_except(mgr, grp, free_in_env);
for(auto& free : free_in_type) {
if(free_in_env.find(free) != free_in_env.end()) continue;
names_it->second.type->forall.push_back(free);
} }
} }

View File

@@ -1,25 +1,48 @@
#pragma once #pragma once
#include <map> #include <map>
#include <string> #include <string>
#include <set>
#include "graph.hpp"
#include "type.hpp" #include "type.hpp"
struct type_env; struct type_env;
using type_env_ptr = std::shared_ptr<type_env>; using type_env_ptr = std::shared_ptr<type_env>;
enum class visibility { global,local };
struct type_env { struct type_env {
struct variable_data {
type_scheme_ptr type;
visibility vis;
std::string mangled_name;
variable_data()
: variable_data(nullptr, visibility::local, "") {}
variable_data(type_scheme_ptr t, visibility v, std::string n)
: type(std::move(t)), vis(v), mangled_name(std::move(n)) {}
};
type_env_ptr parent; type_env_ptr parent;
std::map<std::string, type_scheme_ptr> names; std::map<std::string, variable_data> names;
std::map<std::string, type_ptr> type_names; std::map<std::string, type_ptr> type_names;
type_env(type_env_ptr p) : parent(std::move(p)) {} type_env(type_env_ptr p) : parent(std::move(p)) {}
type_env() : type_env(nullptr) {} type_env() : type_env(nullptr) {}
void find_free(const type_mgr& mgr, std::set<std::string>& into) const;
void find_free_except(const type_mgr& mgr, const group& avoid,
std::set<std::string>& into) const;
type_scheme_ptr lookup(const std::string& name) const; type_scheme_ptr lookup(const std::string& name) const;
bool is_global(const std::string& name) const;
void set_mangled_name(const std::string& name, const std::string& mangled);
const std::string& get_mangled_name(const std::string& name) const;
type_ptr lookup_type(const std::string& name) const; type_ptr lookup_type(const std::string& name) const;
void bind(const std::string& name, type_ptr t); void bind(const std::string& name, type_ptr t,
void bind(const std::string& name, type_scheme_ptr t); visibility v = visibility::local);
void bind(const std::string& name, type_scheme_ptr t,
visibility v = visibility::local);
void bind_type(const std::string& type_name, type_ptr t); void bind_type(const std::string& type_name, type_ptr t);
void generalize(const std::string& name, type_mgr& mgr); void generalize(const std::string& name, const group& grp, type_mgr& mgr);
}; };

View File

@@ -0,0 +1,53 @@
cmake_minimum_required(VERSION 3.1)
project(compiler)
# We want C++17 for std::optional
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Find all the required packages
find_package(BISON)
find_package(FLEX)
find_package(LLVM REQUIRED CONFIG)
# Set up the flex and bison targets
bison_target(parser
${CMAKE_CURRENT_SOURCE_DIR}/parser.y
${CMAKE_CURRENT_BINARY_DIR}/parser.cpp
COMPILE_FLAGS "-d")
flex_target(scanner
${CMAKE_CURRENT_SOURCE_DIR}/scanner.l
${CMAKE_CURRENT_BINARY_DIR}/scanner.cpp)
add_flex_bison_dependency(scanner parser)
# Find all the relevant LLVM components
llvm_map_components_to_libnames(LLVM_LIBS core x86asmparser x86codegen)
# Create compiler executable
add_executable(compiler
definition.cpp definition.hpp
parsed_type.cpp parsed_type.hpp
ast.cpp ast.hpp
llvm_context.cpp llvm_context.hpp
type_env.cpp type_env.hpp
env.cpp env.hpp
type.cpp type.hpp
error.cpp error.hpp
binop.cpp binop.hpp
instruction.cpp instruction.hpp
graph.cpp graph.hpp
global_scope.cpp global_scope.hpp
parse_driver.cpp parse_driver.hpp
mangler.cpp mangler.hpp
compiler.cpp compiler.hpp
${BISON_parser_OUTPUTS}
${FLEX_scanner_OUTPUTS}
main.cpp
)
# Configure compiler executable
target_include_directories(compiler PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
target_include_directories(compiler PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
target_include_directories(compiler PUBLIC ${LLVM_INCLUDE_DIRS})
target_compile_definitions(compiler PUBLIC ${LLVM_DEFINITIONS})
target_link_libraries(compiler ${LLVM_LIBS})

454
code/compiler/13/ast.cpp Normal file
View File

@@ -0,0 +1,454 @@
#include "ast.hpp"
#include <ostream>
#include <type_traits>
#include "binop.hpp"
#include "error.hpp"
#include "instruction.hpp"
#include "type.hpp"
#include "type_env.hpp"
#include "env.hpp"
static void print_indent(int n, std::ostream& to) {
while(n--) to << " ";
}
void ast_int::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "INT: " << value << std::endl;
}
void ast_int::find_free(std::set<std::string>& into) {
}
type_ptr ast_int::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
return type_ptr(new type_app(env->lookup_type("Int")));
}
void ast_int::translate(global_scope& scope) {
}
void ast_int::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr(new instruction_pushint(value)));
}
void ast_lid::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "LID: " << id << std::endl;
}
void ast_lid::find_free(std::set<std::string>& into) {
into.insert(id);
}
type_ptr ast_lid::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
type_scheme_ptr lid_type = env->lookup(id);
if(!lid_type)
throw type_error("unknown identifier " + id, loc);
return lid_type->instantiate(mgr);
}
void ast_lid::translate(global_scope& scope) {
}
void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr(
(this->env->is_global(id)) ?
(instruction*) new instruction_pushglobal(this->env->get_mangled_name(id)) :
(instruction*) new instruction_push(env->get_offset(id))));
}
void ast_uid::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "UID: " << id << std::endl;
}
void ast_uid::find_free(std::set<std::string>& into) {
}
type_ptr ast_uid::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
type_scheme_ptr uid_type = env->lookup(id);
if(!uid_type)
throw type_error("unknown constructor " + id, loc);
return uid_type->instantiate(mgr);
}
void ast_uid::translate(global_scope& scope) {
}
void ast_uid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr(
new instruction_pushglobal(this->env->get_mangled_name(id))));
}
void ast_binop::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "BINOP: " << op_name(op) << std::endl;
left->print(indent + 1, to);
right->print(indent + 1, to);
}
void ast_binop::find_free(std::set<std::string>& into) {
left->find_free(into);
right->find_free(into);
}
type_ptr ast_binop::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
type_ptr ltype = left->typecheck(mgr, env);
type_ptr rtype = right->typecheck(mgr, env);
type_ptr ftype = env->lookup(op_name(op))->instantiate(mgr);
if(!ftype) throw type_error("unknown binary operator " + op_name(op), loc);
// For better type errors, we first require binary function,
// and only then unify each argument. This way, we can
// precisely point out which argument is "wrong".
type_ptr return_type = mgr.new_type();
type_ptr second_type = mgr.new_type();
type_ptr first_type = mgr.new_type();
type_ptr arrow_one = type_ptr(new type_arr(second_type, return_type));
type_ptr arrow_two = type_ptr(new type_arr(first_type, arrow_one));
mgr.unify(ftype, arrow_two, loc);
mgr.unify(first_type, ltype, left->loc);
mgr.unify(second_type, rtype, right->loc);
return return_type;
}
void ast_binop::translate(global_scope& scope) {
left->translate(scope);
right->translate(scope);
}
void ast_binop::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
right->compile(env, into);
left->compile(env_ptr(new env_offset(1, env)), into);
auto mangled_name = this->env->get_mangled_name(op_name(op));
into.push_back(instruction_ptr(new instruction_pushglobal(mangled_name)));
into.push_back(instruction_ptr(new instruction_mkapp()));
into.push_back(instruction_ptr(new instruction_mkapp()));
}
void ast_app::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "APP:" << std::endl;
left->print(indent + 1, to);
right->print(indent + 1, to);
}
void ast_app::find_free(std::set<std::string>& into) {
left->find_free(into);
right->find_free(into);
}
type_ptr ast_app::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
type_ptr ltype = left->typecheck(mgr, env);
type_ptr rtype = right->typecheck(mgr, env);
type_ptr return_type = mgr.new_type();
type_ptr arrow = type_ptr(new type_arr(rtype, return_type));
mgr.unify(arrow, ltype, left->loc);
return return_type;
}
void ast_app::translate(global_scope& scope) {
left->translate(scope);
right->translate(scope);
}
void ast_app::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
right->compile(env, into);
left->compile(env_ptr(new env_offset(1, env)), into);
into.push_back(instruction_ptr(new instruction_mkapp()));
}
void ast_case::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "CASE: " << std::endl;
for(auto& branch : branches) {
print_indent(indent + 1, to);
branch->pat->print(to);
to << std::endl;
branch->expr->print(indent + 2, to);
}
}
void ast_case::find_free(std::set<std::string>& into) {
of->find_free(into);
for(auto& branch : branches) {
std::set<std::string> free_in_branch;
std::set<std::string> pattern_variables;
branch->pat->find_variables(pattern_variables);
branch->expr->find_free(free_in_branch);
for(auto& free : free_in_branch) {
if(pattern_variables.find(free) == pattern_variables.end())
into.insert(free);
}
}
}
type_ptr ast_case::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
type_var* var;
type_ptr case_type = mgr.resolve(of->typecheck(mgr, env), var);
type_ptr branch_type = mgr.new_type();
for(auto& branch : branches) {
type_env_ptr new_env = type_scope(env);
branch->pat->typecheck(case_type, mgr, new_env);
type_ptr curr_branch_type = branch->expr->typecheck(mgr, new_env);
mgr.unify(curr_branch_type, branch_type, branch->expr->loc);
}
input_type = mgr.resolve(case_type, var);
type_app* app_type;
if(!(app_type = dynamic_cast<type_app*>(input_type.get())) ||
!dynamic_cast<type_data*>(app_type->constructor.get())) {
throw type_error("attempting case analysis of non-data type");
}
return branch_type;
}
void ast_case::translate(global_scope& scope) {
of->translate(scope);
for(auto& branch : branches) {
branch->expr->translate(scope);
}
}
void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
type_app* app_type = dynamic_cast<type_app*>(input_type.get());
type_data* type = dynamic_cast<type_data*>(app_type->constructor.get());
of->compile(env, into);
into.push_back(instruction_ptr(new instruction_eval()));
instruction_jump* jump_instruction = new instruction_jump();
into.push_back(instruction_ptr(jump_instruction));
for(auto& branch : branches) {
std::vector<instruction_ptr> branch_instructions;
pattern_var* vpat;
pattern_constr* cpat;
if((vpat = dynamic_cast<pattern_var*>(branch->pat.get()))) {
branch->expr->compile(env_ptr(new env_offset(1, env)), branch_instructions);
for(auto& constr_pair : type->constructors) {
if(jump_instruction->tag_mappings.find(constr_pair.second.tag) !=
jump_instruction->tag_mappings.end())
break;
jump_instruction->tag_mappings[constr_pair.second.tag] =
jump_instruction->branches.size();
}
jump_instruction->branches.push_back(std::move(branch_instructions));
} else if((cpat = dynamic_cast<pattern_constr*>(branch->pat.get()))) {
env_ptr new_env = env;
for(auto it = cpat->params.rbegin(); it != cpat->params.rend(); it++) {
new_env = env_ptr(new env_var(*it, new_env));
}
branch_instructions.push_back(instruction_ptr(new instruction_split(
cpat->params.size())));
branch->expr->compile(new_env, branch_instructions);
branch_instructions.push_back(instruction_ptr(new instruction_slide(
cpat->params.size())));
int new_tag = type->constructors[cpat->constr].tag;
if(jump_instruction->tag_mappings.find(new_tag) !=
jump_instruction->tag_mappings.end())
throw type_error("technically not a type error: duplicate pattern");
jump_instruction->tag_mappings[new_tag] =
jump_instruction->branches.size();
jump_instruction->branches.push_back(std::move(branch_instructions));
}
}
for(auto& constr_pair : type->constructors) {
if(jump_instruction->tag_mappings.find(constr_pair.second.tag) ==
jump_instruction->tag_mappings.end())
throw type_error("non-total pattern");
}
}
void ast_let::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "LET: " << std::endl;
in->print(indent + 1, to);
}
void ast_let::find_free(std::set<std::string>& into) {
definitions.find_free(into);
std::set<std::string> all_free;
in->find_free(all_free);
for(auto& free_var : all_free) {
if(definitions.defs_defn.find(free_var) == definitions.defs_defn.end())
into.insert(free_var);
}
}
type_ptr ast_let::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
definitions.typecheck(mgr, env);
return in->typecheck(mgr, definitions.env);
}
void ast_let::translate(global_scope& scope) {
for(auto& def : definitions.defs_data) {
def.second->into_globals(scope);
}
for(auto& def : definitions.defs_defn) {
size_t original_params = def.second->params.size();
std::string original_name = def.second->name;
auto& global_definition = def.second->into_global(scope);
size_t captured = global_definition.params.size() - original_params;
type_env_ptr mangled_env = type_scope(env);
mangled_env->bind(def.first, env->lookup(def.first), visibility::global);
mangled_env->set_mangled_name(def.first, global_definition.name);
ast_ptr global_app(new ast_lid(original_name));
global_app->env = mangled_env;
for(auto& param : global_definition.params) {
if(!(captured--)) break;
ast_ptr new_arg(new ast_lid(param));
new_arg->env = env;
global_app = ast_ptr(new ast_app(std::move(global_app), std::move(new_arg)));
global_app->env = env;
}
translated_definitions.push_back({ def.first, std::move(global_app) });
}
in->translate(scope);
}
void ast_let::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr(new instruction_alloc(translated_definitions.size())));
env_ptr new_env = env;
for(auto& def : translated_definitions) {
new_env = env_ptr(new env_var(def.first, std::move(new_env)));
}
int offset = translated_definitions.size() - 1;
for(auto& def : translated_definitions) {
def.second->compile(new_env, into);
into.push_back(instruction_ptr(new instruction_update(offset--)));
}
in->compile(new_env, into);
into.push_back(instruction_ptr(new instruction_slide(translated_definitions.size())));
}
void ast_lambda::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "LAMBDA";
for(auto& param : params) {
to << " " << param;
}
to << std::endl;
body->print(indent+1, to);
}
void ast_lambda::find_free(std::set<std::string>& into) {
body->find_free(free_variables);
for(auto& param : params) {
free_variables.erase(param);
}
into.insert(free_variables.begin(), free_variables.end());
}
type_ptr ast_lambda::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = env;
var_env = type_scope(env);
type_ptr return_type = mgr.new_type();
type_ptr full_type = return_type;
for(auto it = params.rbegin(); it != params.rend(); it++) {
type_ptr param_type = mgr.new_type();
var_env->bind(*it, param_type);
full_type = type_ptr(new type_arr(std::move(param_type), full_type));
}
mgr.unify(return_type, body->typecheck(mgr, var_env), body->loc);
return full_type;
}
void ast_lambda::translate(global_scope& scope) {
std::vector<std::string> function_params;
for(auto& free_variable : free_variables) {
if(env->is_global(free_variable)) continue;
function_params.push_back(free_variable);
}
size_t captured_count = function_params.size();
function_params.insert(function_params.end(), params.begin(), params.end());
auto& new_function = scope.add_function("lambda", std::move(function_params), std::move(body));
type_env_ptr mangled_env = type_scope(env);
mangled_env->bind("lambda", type_scheme_ptr(nullptr), visibility::global);
mangled_env->set_mangled_name("lambda", new_function.name);
ast_ptr new_application = ast_ptr(new ast_lid("lambda"));
new_application->env = mangled_env;
for(auto& param : new_function.params) {
if(!(captured_count--)) break;
ast_ptr new_arg = ast_ptr(new ast_lid(param));
new_arg->env = env;
new_application = ast_ptr(new ast_app(std::move(new_application), std::move(new_arg)));
new_application->env = env;
}
translated = std::move(new_application);
}
void ast_lambda::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
translated->compile(env, into);
}
void pattern_var::print(std::ostream& to) const {
to << var;
}
void pattern_var::find_variables(std::set<std::string>& into) const {
into.insert(var);
}
void pattern_var::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const {
env->bind(var, t);
}
void pattern_constr::print(std::ostream& to) const {
to << constr;
for(auto& param : params) {
to << " " << param;
}
}
void pattern_constr::find_variables(std::set<std::string>& into) const {
into.insert(params.begin(), params.end());
}
void pattern_constr::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const {
type_scheme_ptr constructor_type_scheme = env->lookup(constr);
if(!constructor_type_scheme) {
throw type_error("pattern using unknown constructor " + constr, loc);
}
type_ptr constructor_type = constructor_type_scheme->instantiate(mgr);
for(auto& param : params) {
type_arr* arr = dynamic_cast<type_arr*>(constructor_type.get());
if(!arr) throw type_error("too many parameters in constructor pattern", loc);
env->bind(param, arr->left);
constructor_type = arr->right;
}
mgr.unify(t, constructor_type, loc);
}

195
code/compiler/13/ast.hpp Normal file
View File

@@ -0,0 +1,195 @@
#pragma once
#include <memory>
#include <vector>
#include <set>
#include "type.hpp"
#include "type_env.hpp"
#include "binop.hpp"
#include "instruction.hpp"
#include "env.hpp"
#include "definition.hpp"
#include "location.hh"
#include "global_scope.hpp"
struct ast {
type_env_ptr env;
yy::location loc;
ast(yy::location l) : env(nullptr), loc(std::move(l)) {}
virtual ~ast() = default;
virtual void print(int indent, std::ostream& to) const = 0;
virtual void find_free(std::set<std::string>& into) = 0;
virtual type_ptr typecheck(type_mgr& mgr, type_env_ptr& env) = 0;
virtual void translate(global_scope& scope) = 0;
virtual void compile(const env_ptr& env,
std::vector<instruction_ptr>& into) const = 0;
};
using ast_ptr = std::unique_ptr<ast>;
struct pattern {
yy::location loc;
pattern(yy::location l) : loc(std::move(l)) {}
virtual ~pattern() = default;
virtual void print(std::ostream& to) const = 0;
virtual void find_variables(std::set<std::string>& into) const = 0;
virtual void typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const = 0;
};
using pattern_ptr = std::unique_ptr<pattern>;
struct branch {
pattern_ptr pat;
ast_ptr expr;
branch(pattern_ptr p, ast_ptr a)
: pat(std::move(p)), expr(std::move(a)) {}
};
using branch_ptr = std::unique_ptr<branch>;
struct ast_int : public ast {
int value;
explicit ast_int(int v, yy::location l = yy::location())
: ast(std::move(l)), value(v) {}
void print(int indent, std::ostream& to) const;
void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
};
struct ast_lid : public ast {
std::string id;
explicit ast_lid(std::string i, yy::location l = yy::location())
: ast(std::move(l)), id(std::move(i)) {}
void print(int indent, std::ostream& to) const;
void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
};
struct ast_uid : public ast {
std::string id;
explicit ast_uid(std::string i, yy::location l = yy::location())
: ast(std::move(l)), id(std::move(i)) {}
void print(int indent, std::ostream& to) const;
void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
};
struct ast_binop : public ast {
binop op;
ast_ptr left;
ast_ptr right;
ast_binop(binop o, ast_ptr l, ast_ptr r, yy::location lc = yy::location())
: ast(std::move(lc)), op(o), left(std::move(l)), right(std::move(r)) {}
void print(int indent, std::ostream& to) const;
void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
};
struct ast_app : public ast {
ast_ptr left;
ast_ptr right;
ast_app(ast_ptr l, ast_ptr r, yy::location lc = yy::location())
: ast(std::move(lc)), left(std::move(l)), right(std::move(r)) {}
void print(int indent, std::ostream& to) const;
void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
};
struct ast_case : public ast {
ast_ptr of;
type_ptr input_type;
std::vector<branch_ptr> branches;
ast_case(ast_ptr o, std::vector<branch_ptr> b, yy::location l = yy::location())
: ast(std::move(l)), of(std::move(o)), branches(std::move(b)) {}
void print(int indent, std::ostream& to) const;
void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
};
struct ast_let : public ast {
using basic_definition = std::pair<std::string, ast_ptr>;
definition_group definitions;
ast_ptr in;
std::vector<basic_definition> translated_definitions;
ast_let(definition_group g, ast_ptr i, yy::location l = yy::location())
: ast(std::move(l)), definitions(std::move(g)), in(std::move(i)) {}
void print(int indent, std::ostream& to) const;
void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
};
struct ast_lambda : public ast {
std::vector<std::string> params;
ast_ptr body;
type_env_ptr var_env;
std::set<std::string> free_variables;
ast_ptr translated;
ast_lambda(std::vector<std::string> ps, ast_ptr b, yy::location l = yy::location())
: ast(std::move(l)), params(std::move(ps)), body(std::move(b)) {}
void print(int indent, std::ostream& to) const;
void find_free(std::set<std::string>& into);
type_ptr typecheck(type_mgr& mgr, type_env_ptr& env);
void translate(global_scope& scope);
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
};
struct pattern_var : public pattern {
std::string var;
pattern_var(std::string v, yy::location l = yy::location())
: pattern(std::move(l)), var(std::move(v)) {}
void print(std::ostream &to) const;
void find_variables(std::set<std::string>& into) const;
void typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const;
};
struct pattern_constr : public pattern {
std::string constr;
std::vector<std::string> params;
pattern_constr(std::string c, std::vector<std::string> p, yy::location l = yy::location())
: pattern(std::move(l)), constr(std::move(c)), params(std::move(p)) {}
void print(std::ostream &to) const;
void find_variables(std::set<std::string>& into) const;
virtual void typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const;
};

View File

@@ -0,0 +1,21 @@
#include "binop.hpp"
std::string op_name(binop op) {
switch(op) {
case PLUS: return "+";
case MINUS: return "-";
case TIMES: return "*";
case DIVIDE: return "/";
}
return "??";
}
std::string op_action(binop op) {
switch(op) {
case PLUS: return "plus";
case MINUS: return "minus";
case TIMES: return "times";
case DIVIDE: return "divide";
}
return "??";
}

View File

@@ -0,0 +1,17 @@
#pragma once
#include <array>
#include <string>
enum binop {
PLUS,
MINUS,
TIMES,
DIVIDE
};
constexpr binop all_binops[] = {
PLUS, MINUS, TIMES, DIVIDE
};
std::string op_name(binop op);
std::string op_action(binop op);

View File

@@ -0,0 +1,153 @@
#include "compiler.hpp"
#include "binop.hpp"
#include "error.hpp"
#include "global_scope.hpp"
#include "parse_driver.hpp"
#include "type.hpp"
#include "type_env.hpp"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Target/TargetMachine.h"
void compiler::add_default_types() {
global_env->bind_type("Int", type_ptr(new type_base("Int")));
}
void compiler::add_binop_type(binop op, type_ptr type) {
auto name = mng.new_mangled_name(op_action(op));
global_env->bind(op_name(op), std::move(type), visibility::global);
global_env->set_mangled_name(op_name(op), name);
}
void compiler::add_default_function_types() {
type_ptr int_type = global_env->lookup_type("Int");
assert(int_type != nullptr);
type_ptr int_type_app = type_ptr(new type_app(int_type));
type_ptr closed_int_op_type(
new type_arr(int_type_app, type_ptr(new type_arr(int_type_app, int_type_app))));
constexpr binop closed_ops[] = { PLUS, MINUS, TIMES, DIVIDE };
for(auto& op : closed_ops) add_binop_type(op, closed_int_op_type);
}
void compiler::parse() {
if(!driver())
throw compiler_error("failed to open file");
}
void compiler::typecheck() {
std::set<std::string> free_variables;
global_defs.find_free(free_variables);
global_defs.typecheck(type_m, global_env);
}
void compiler::translate() {
for(auto& data : global_defs.defs_data) {
data.second->into_globals(global_scp);
}
for(auto& defn : global_defs.defs_defn) {
auto& function = defn.second->into_global(global_scp);
defn.second->env->set_mangled_name(defn.first, function.name);
}
}
void compiler::compile() {
global_scp.compile();
}
void compiler::create_llvm_binop(binop op) {
auto new_function =
ctx.create_custom_function(global_env->get_mangled_name(op_name(op)), 2);
std::vector<instruction_ptr> instructions;
instructions.push_back(instruction_ptr(new instruction_push(1)));
instructions.push_back(instruction_ptr(new instruction_eval()));
instructions.push_back(instruction_ptr(new instruction_push(1)));
instructions.push_back(instruction_ptr(new instruction_eval()));
instructions.push_back(instruction_ptr(new instruction_binop(op)));
instructions.push_back(instruction_ptr(new instruction_update(2)));
instructions.push_back(instruction_ptr(new instruction_pop(2)));
ctx.get_builder().SetInsertPoint(&new_function->getEntryBlock());
for(auto& instruction : instructions) {
instruction->gen_llvm(ctx, new_function);
}
ctx.get_builder().CreateRetVoid();
}
void compiler::generate_llvm() {
for(auto op : all_binops) {
create_llvm_binop(op);
}
global_scp.generate_llvm(ctx);
}
void compiler::output_llvm(const std::string& into) {
std::string targetTriple = llvm::sys::getDefaultTargetTriple();
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmParser();
llvm::InitializeNativeTargetAsmPrinter();
std::string error;
const llvm::Target* target =
llvm::TargetRegistry::lookupTarget(targetTriple, error);
if (!target) {
std::cerr << error << std::endl;
} else {
std::string cpu = "generic";
std::string features = "";
llvm::TargetOptions options;
std::unique_ptr<llvm::TargetMachine> targetMachine(
target->createTargetMachine(targetTriple, cpu, features,
options, llvm::Optional<llvm::Reloc::Model>()));
ctx.get_module().setDataLayout(targetMachine->createDataLayout());
ctx.get_module().setTargetTriple(targetTriple);
std::error_code ec;
llvm::raw_fd_ostream file(into, ec, llvm::sys::fs::F_None);
if (ec) {
throw compiler_error("failed to open object file for writing");
} else {
llvm::CodeGenFileType type = llvm::CGFT_ObjectFile;
llvm::legacy::PassManager pm;
if (targetMachine->addPassesToEmitFile(pm, file, NULL, type)) {
throw compiler_error("failed to add passes to pass manager");
} else {
pm.run(ctx.get_module());
file.close();
}
}
}
}
compiler::compiler(const std::string& filename)
: file_m(), global_defs(), driver(file_m, global_defs, filename),
global_env(new type_env), type_m(), mng(), global_scp(mng), ctx() {
add_default_types();
add_default_function_types();
}
void compiler::operator()(const std::string& into) {
parse();
typecheck();
translate();
compile();
generate_llvm();
output_llvm(into);
}
file_mgr& compiler::get_file_manager() {
return file_m;
}
type_mgr& compiler::get_type_manager() {
return type_m;
}

View File

@@ -0,0 +1,37 @@
#pragma once
#include "binop.hpp"
#include "parse_driver.hpp"
#include "definition.hpp"
#include "type_env.hpp"
#include "type.hpp"
#include "global_scope.hpp"
#include "mangler.hpp"
#include "llvm_context.hpp"
class compiler {
private:
file_mgr file_m;
definition_group global_defs;
parse_driver driver;
type_env_ptr global_env;
type_mgr type_m;
mangler mng;
global_scope global_scp;
llvm_context ctx;
void add_default_types();
void add_binop_type(binop op, type_ptr type);
void add_default_function_types();
void parse();
void typecheck();
void translate();
void compile();
void create_llvm_binop(binop op);
void generate_llvm();
void output_llvm(const std::string& into);
public:
compiler(const std::string& filename);
void operator()(const std::string& into);
file_mgr& get_file_manager();
type_mgr& get_type_manager();
};

View File

@@ -0,0 +1,148 @@
#include "definition.hpp"
#include <cassert>
#include "error.hpp"
#include "ast.hpp"
#include "instruction.hpp"
#include "llvm_context.hpp"
#include "type.hpp"
#include "type_env.hpp"
#include "graph.hpp"
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Type.h>
void definition_defn::find_free() {
body->find_free(free_variables);
for(auto& param : params) {
free_variables.erase(param);
}
}
void definition_defn::insert_types(type_mgr& mgr, type_env_ptr& env, visibility v) {
this->env = env;
var_env = type_scope(env);
return_type = mgr.new_type();
full_type = return_type;
for(auto it = params.rbegin(); it != params.rend(); it++) {
type_ptr param_type = mgr.new_type();
full_type = type_ptr(new type_arr(param_type, full_type));
var_env->bind(*it, param_type);
}
env->bind(name, full_type, v);
}
void definition_defn::typecheck(type_mgr& mgr) {
type_ptr body_type = body->typecheck(mgr, var_env);
mgr.unify(return_type, body_type);
}
global_function& definition_defn::into_global(global_scope& scope) {
std::vector<std::string> all_params;
for(auto& free : free_variables) {
if(env->is_global(free)) continue;
all_params.push_back(free);
}
all_params.insert(all_params.end(), params.begin(), params.end());
body->translate(scope);
return scope.add_function(name, std::move(all_params), std::move(body));
}
void definition_data::insert_types(type_env_ptr& env) {
this->env = env;
env->bind_type(name, type_ptr(new type_data(name, vars.size())));
}
void definition_data::insert_constructors() const {
type_ptr this_type_ptr = env->lookup_type(name);
type_data* this_type = static_cast<type_data*>(this_type_ptr.get());
int next_tag = 0;
std::set<std::string> var_set;
type_app* return_app = new type_app(std::move(this_type_ptr));
type_ptr return_type(return_app);
for(auto& var : vars) {
if(var_set.find(var) != var_set.end())
throw compiler_error(
"type variable " + var +
" used twice in data type definition.", loc);
var_set.insert(var);
return_app->arguments.push_back(type_ptr(new type_var(var)));
}
for(auto& constructor : constructors) {
constructor->tag = next_tag;
this_type->constructors[constructor->name] = { next_tag++ };
type_ptr full_type = return_type;
for(auto it = constructor->types.rbegin(); it != constructor->types.rend(); it++) {
type_ptr type = (*it)->to_type(var_set, env);
full_type = type_ptr(new type_arr(type, full_type));
}
type_scheme_ptr full_scheme(new type_scheme(std::move(full_type)));
full_scheme->forall.insert(full_scheme->forall.begin(), vars.begin(), vars.end());
env->bind(constructor->name, full_scheme, visibility::global);
}
}
void definition_data::into_globals(global_scope& scope) {
for(auto& constructor : constructors) {
global_constructor& c = scope.add_constructor(
constructor->name, constructor->tag, constructor->types.size());
env->set_mangled_name(constructor->name, c.name);
}
}
void definition_group::find_free(std::set<std::string>& into) {
for(auto& def_pair : defs_defn) {
def_pair.second->find_free();
for(auto& free_var : def_pair.second->free_variables) {
if(defs_defn.find(free_var) == defs_defn.end()) {
into.insert(free_var);
} else {
def_pair.second->nearby_variables.insert(free_var);
}
}
}
}
void definition_group::typecheck(type_mgr& mgr, type_env_ptr& env) {
this->env = type_scope(env);
for(auto& def_data : defs_data) {
def_data.second->insert_types(this->env);
}
for(auto& def_data : defs_data) {
def_data.second->insert_constructors();
}
function_graph dependency_graph;
for(auto& def_defn : defs_defn) {
def_defn.second->find_free();
dependency_graph.add_function(def_defn.second->name);
for(auto& dependency : def_defn.second->nearby_variables) {
assert(defs_defn.find(dependency) != defs_defn.end());
dependency_graph.add_edge(def_defn.second->name, dependency);
}
}
std::vector<group_ptr> groups = dependency_graph.compute_order();
for(auto it = groups.rbegin(); it != groups.rend(); it++) {
auto& group = *it;
for(auto& def_defnn_name : group->members) {
auto& def_defn = defs_defn.find(def_defnn_name)->second;
def_defn->insert_types(mgr, this->env, vis);
}
for(auto& def_defnn_name : group->members) {
auto& def_defn = defs_defn.find(def_defnn_name)->second;
def_defn->typecheck(mgr);
}
for(auto& def_defnn_name : group->members) {
this->env->generalize(def_defnn_name, *group, mgr);
}
}
}

View File

@@ -0,0 +1,91 @@
#pragma once
#include <memory>
#include <vector>
#include <map>
#include <set>
#include "instruction.hpp"
#include "llvm_context.hpp"
#include "parsed_type.hpp"
#include "type_env.hpp"
#include "location.hh"
#include "global_scope.hpp"
struct ast;
using ast_ptr = std::unique_ptr<ast>;
struct constructor {
std::string name;
std::vector<parsed_type_ptr> types;
int8_t tag;
constructor(std::string n, std::vector<parsed_type_ptr> ts)
: name(std::move(n)), types(std::move(ts)) {}
};
using constructor_ptr = std::unique_ptr<constructor>;
struct definition_defn {
std::string name;
std::vector<std::string> params;
ast_ptr body;
yy::location loc;
type_env_ptr env;
type_env_ptr var_env;
std::set<std::string> free_variables;
std::set<std::string> nearby_variables;
type_ptr full_type;
type_ptr return_type;
definition_defn(
std::string n,
std::vector<std::string> p,
ast_ptr b,
yy::location l = yy::location())
: name(std::move(n)), params(std::move(p)), body(std::move(b)), loc(std::move(l)) {
}
void find_free();
void insert_types(type_mgr& mgr, type_env_ptr& env, visibility v);
void typecheck(type_mgr& mgr);
global_function& into_global(global_scope& scope);
};
using definition_defn_ptr = std::unique_ptr<definition_defn>;
struct definition_data {
std::string name;
std::vector<std::string> vars;
std::vector<constructor_ptr> constructors;
yy::location loc;
type_env_ptr env;
definition_data(
std::string n,
std::vector<std::string> vs,
std::vector<constructor_ptr> cs,
yy::location l = yy::location())
: name(std::move(n)), vars(std::move(vs)), constructors(std::move(cs)), loc(std::move(l)) {}
void insert_types(type_env_ptr& env);
void insert_constructors() const;
void into_globals(global_scope& scope);
};
using definition_data_ptr = std::unique_ptr<definition_data>;
struct definition_group {
std::map<std::string, definition_data_ptr> defs_data;
std::map<std::string, definition_defn_ptr> defs_defn;
visibility vis;
type_env_ptr env;
definition_group(visibility v = visibility::local) : vis(v) {}
void find_free(std::set<std::string>& into);
void typecheck(type_mgr& mgr, type_env_ptr& env);
};

24
code/compiler/13/env.cpp Normal file
View File

@@ -0,0 +1,24 @@
#include "env.hpp"
#include <cassert>
int env_var::get_offset(const std::string& name) const {
if(name == this->name) return 0;
assert(parent != nullptr);
return parent->get_offset(name) + 1;
}
bool env_var::has_variable(const std::string& name) const {
if(name == this->name) return true;
if(parent) return parent->has_variable(name);
return false;
}
int env_offset::get_offset(const std::string& name) const {
assert(parent != nullptr);
return parent->get_offset(name) + offset;
}
bool env_offset::has_variable(const std::string& name) const {
if(parent) return parent->has_variable(name);
return false;
}

39
code/compiler/13/env.hpp Normal file
View File

@@ -0,0 +1,39 @@
#pragma once
#include <memory>
#include <string>
class env {
public:
virtual ~env() = default;
virtual int get_offset(const std::string& name) const = 0;
virtual bool has_variable(const std::string& name) const = 0;
};
using env_ptr = std::shared_ptr<env>;
class env_var : public env {
private:
std::string name;
env_ptr parent;
public:
env_var(std::string n, env_ptr p)
: name(std::move(n)), parent(std::move(p)) {}
int get_offset(const std::string& name) const;
bool has_variable(const std::string& name) const;
};
class env_offset : public env {
private:
int offset;
env_ptr parent;
public:
env_offset(int o, env_ptr p)
: offset(o), parent(std::move(p)) {}
int get_offset(const std::string& name) const;
bool has_variable(const std::string& name) const;
};

View File

@@ -0,0 +1,41 @@
#include "error.hpp"
const char* compiler_error::what() const noexcept {
return "an error occured while compiling the program";
}
void compiler_error::print_about(std::ostream& to) {
to << what() << ": ";
to << description << std::endl;
}
void compiler_error::print_location(std::ostream& to, file_mgr& fm, bool highlight) {
if(!loc) return;
to << "occuring on line " << loc->begin.line << ":" << std::endl;
fm.print_location(to, *loc, highlight);
}
void compiler_error::pretty_print(std::ostream& to, file_mgr& fm) {
print_about(to);
print_location(to, fm);
}
const char* type_error::what() const noexcept {
return "an error occured while checking the types of the program";
}
void type_error::pretty_print(std::ostream& to, file_mgr& fm) {
print_about(to);
print_location(to, fm, true);
}
void unification_error::pretty_print(std::ostream& to, file_mgr& fm, type_mgr& mgr) {
type_error::pretty_print(to, fm);
to << "the expected type was:" << std::endl;
to << " \033[34m";
left->print(mgr, to);
to << std::endl << "\033[0mwhile the actual type was:" << std::endl;
to << " \033[32m";
right->print(mgr, to);
to << "\033[0m" << std::endl;
}

View File

@@ -0,0 +1,49 @@
#pragma once
#include <exception>
#include <optional>
#include "type.hpp"
#include "location.hh"
#include "parse_driver.hpp"
using maybe_location = std::optional<yy::location>;
class compiler_error : std::exception {
private:
std::string description;
maybe_location loc;
public:
compiler_error(std::string d, maybe_location l = std::nullopt)
: description(std::move(d)), loc(std::move(l)) {}
const char* what() const noexcept override;
void print_about(std::ostream& to);
void print_location(std::ostream& to, file_mgr& fm, bool highlight = false);
void pretty_print(std::ostream& to, file_mgr& fm);
};
class type_error : compiler_error {
private:
public:
type_error(std::string d, maybe_location l = std::nullopt)
: compiler_error(std::move(d), std::move(l)) {}
const char* what() const noexcept override;
void pretty_print(std::ostream& to, file_mgr& fm);
};
class unification_error : public type_error {
private:
type_ptr left;
type_ptr right;
public:
unification_error(type_ptr l, type_ptr r, maybe_location loc = std::nullopt)
: left(std::move(l)), right(std::move(r)),
type_error("failed to unify types", std::move(loc)) {}
void pretty_print(std::ostream& to, file_mgr& fm, type_mgr& mgr);
};

View File

@@ -0,0 +1,2 @@
data Bool = { True, False }
defn main = { 3 + True }

View File

@@ -0,0 +1 @@
defn main = { 1 2 3 4 5 }

View File

@@ -0,0 +1,8 @@
data List = { Nil, Cons Int List }
defn head l = {
case l of {
Nil -> { 0 }
Cons x y z -> { x }
}
}

View File

@@ -0,0 +1,6 @@
defn main = {
case True of {
n -> { 2 }
n -> { 1 }
}
}

View File

@@ -0,0 +1 @@
data Pair a a = { MkPair a a }

View File

@@ -0,0 +1,7 @@
defn main = {
case True of {
True -> { 1 }
False -> { 0 }
n -> { 2 }
}
}

View File

@@ -0,0 +1,5 @@
defn main = {
case True of {
True -> { 1 }
}
}

View File

@@ -0,0 +1,7 @@
defn add x y = { x + y }
defn main = {
case add of {
n -> { 1 }
}
}

View File

@@ -0,0 +1,7 @@
defn main = {
case True of {
n -> { 2 }
True -> { 1 }
False -> { 0 }
}
}

View File

@@ -0,0 +1,8 @@
data List = { Nil, Cons Int List }
defn head l = {
case l of {
Nil -> { 0 }
Cons x -> { x }
}
}

View File

@@ -0,0 +1,8 @@
data List = { Nil, Cons Int List }
defn head l = {
case l of {
Nil -> { 0 }
Cons x y z -> { x }
}
}

View File

@@ -0,0 +1,6 @@
defn main = {
case True of {
NotBool -> { 1 }
True -> { 2 }
}
}

View File

@@ -0,0 +1 @@
data Bool = { True, False }

View File

@@ -0,0 +1,3 @@
defn main = {
weird 1
}

View File

@@ -0,0 +1 @@
data Wrapper = { Wrap Weird }

View File

@@ -0,0 +1 @@
data Wrapper = { Wrap a }

View File

@@ -0,0 +1,3 @@
defn main = {
Weird 1
}

View File

@@ -0,0 +1 @@
data Wrapper = { Wrap (Int Bool) }

View File

@@ -0,0 +1,17 @@
data List a = { Nil, Cons a (List a) }
defn fix f = { let { defn x = { f x } } in { x } }
defn fixpointOnes fo = { Cons 1 fo }
defn sumTwo l = {
case l of {
Nil -> { 0 }
Cons x xs -> {
x + case xs of {
Nil -> { 0 }
Cons y ys -> { y }
}
}
}
}
defn main = { sumTwo (fix fixpointOnes) }

View File

@@ -0,0 +1,8 @@
data Bool = { True, False }
defn if c t e = {
case c of {
True -> { t }
False -> { e }
}
}
defn main = { if (if True False True) 11 3 }

View File

@@ -0,0 +1,19 @@
data List a = { Nil, Cons a (List a) }
defn sum l = {
case l of {
Nil -> { 0 }
Cons x xs -> { x + sum xs}
}
}
defn map f l = {
case l of {
Nil -> { Nil }
Cons x xs -> { Cons (f x) (map f xs) }
}
}
defn main = {
sum (map \x -> { x * x } (map (\x -> { x + x }) (Cons 1 (Cons 2 (Cons 3 Nil)))))
}

View File

@@ -0,0 +1,47 @@
data Bool = { True, False }
data List a = { Nil, Cons a (List a) }
defn if c t e = {
case c of {
True -> { t }
False -> { e }
}
}
defn mergeUntil l r p = {
let {
defn mergeLeft nl nr = {
case nl of {
Nil -> { Nil }
Cons x xs -> { if (p x) (Cons x (mergeRight xs nr)) Nil }
}
}
defn mergeRight nl nr = {
case nr of {
Nil -> { Nil }
Cons x xs -> { if (p x) (Cons x (mergeLeft nl xs)) Nil }
}
}
} in {
mergeLeft l r
}
}
defn const x y = { x }
defn sum l = {
case l of {
Nil -> { 0 }
Cons x xs -> { x + sum xs }
}
}
defn main = {
let {
defn firstList = { Cons 1 (Cons 3 (Cons 5 Nil)) }
defn secondList = { Cons 2 (Cons 4 (Cons 6 Nil)) }
} in {
sum (mergeUntil firstList secondList (const True))
}
}

View File

@@ -0,0 +1,32 @@
data List a = { Nil, Cons a (List a) }
defn map f l = {
case l of {
Nil -> { Nil }
Cons x xs -> { Cons (f x) (map f xs) }
}
}
defn foldl f b l = {
case l of {
Nil -> { b }
Cons x xs -> { foldl f (f b x) xs }
}
}
defn foldr f b l = {
case l of {
Nil -> { b }
Cons x xs -> { f x (foldr f b xs) }
}
}
defn list = { Cons 1 (Cons 2 (Cons 3 (Cons 4 Nil))) }
defn add x y = { x + y }
defn sum l = { foldr add 0 l }
defn skipAdd x y = { y + 1 }
defn length l = { foldr skipAdd 0 l }
defn main = { sum list + length list }

View File

@@ -0,0 +1,25 @@
data Bool = { True, False }
data List = { Nil, Cons Int List }
defn if c t e = {
case c of {
True -> { t }
False -> { e }
}
}
defn oddEven l e = {
case l of {
Nil -> { e }
Cons x xs -> { evenOdd xs e }
}
}
defn evenOdd l e = {
case l of {
Nil -> { e }
Cons x xs -> { oddEven xs e }
}
}
defn main = { if (oddEven (Cons 1 (Cons 2 (Cons 3 Nil))) True) (oddEven (Cons 1 (Cons 2 (Cons 3 Nil))) 1) 3 }

View File

@@ -0,0 +1,23 @@
data Pair a b = { Pair a b }
defn packer = {
let {
data Packed a = { Packed a }
defn pack a = { Packed a }
defn unpack p = {
case p of {
Packed a -> { a }
}
}
} in {
Pair pack unpack
}
}
defn main = {
case packer of {
Pair pack unpack -> {
unpack (pack 3)
}
}
}

View File

@@ -0,0 +1,17 @@
data Pair a b = { MkPair a b }
defn fst p = {
case p of {
MkPair a b -> { a }
}
}
defn snd p = {
case p of {
MkPair a b -> { b }
}
}
defn pair = { MkPair 1 (MkPair 2 3) }
defn main = { fst pair + snd (snd pair) }

View File

@@ -0,0 +1,122 @@
data List = { Nil, Cons Nat List }
data Bool = { True, False }
data Nat = { O, S Nat }
defn if c t e = {
case c of {
True -> { t }
False -> { e }
}
}
defn toInt n = {
case n of {
O -> { 0 }
S np -> { 1 + toInt np }
}
}
defn lte n m = {
case m of {
O -> {
case n of {
O -> { True }
S np -> { False }
}
}
S mp -> {
case n of {
O -> { True }
S np -> { lte np mp }
}
}
}
}
defn minus n m = {
case m of {
O -> { n }
S mp -> {
case n of {
O -> { O }
S np -> {
minus np mp
}
}
}
}
}
defn mod n m = {
if (lte m n) (mod (minus n m) m) n
}
defn notDivisibleBy n m = {
case (mod m n) of {
O -> { False }
S mp -> { True }
}
}
defn filter f l = {
case l of {
Nil -> { Nil }
Cons x xs -> { if (f x) (Cons x (filter f xs)) (filter f xs) }
}
}
defn map f l = {
case l of {
Nil -> { Nil }
Cons x xs -> { Cons (f x) (map f xs) }
}
}
defn nats = {
Cons (S (S O)) (map S nats)
}
defn primesRec l = {
case l of {
Nil -> { Nil }
Cons p xs -> { Cons p (primesRec (filter (notDivisibleBy p) xs)) }
}
}
defn primes = {
primesRec nats
}
defn take n l = {
case l of {
Nil -> { Nil }
Cons x xs -> {
case n of {
O -> { Nil }
S np -> { Cons x (take np xs) }
}
}
}
}
defn head l = {
case l of {
Nil -> { O }
Cons x xs -> { x }
}
}
defn reverseAcc a l = {
case l of {
Nil -> { a }
Cons x xs -> { reverseAcc (Cons x a) xs }
}
}
defn reverse l = {
reverseAcc Nil l
}
defn main = {
toInt (head (reverse (take ((S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S (S O))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))) primes)))
}

View File

@@ -0,0 +1,31 @@
#include "../runtime.h"
void f_add(struct stack* s) {
struct node_num* left = (struct node_num*) eval(stack_peek(s, 0));
struct node_num* right = (struct node_num*) eval(stack_peek(s, 1));
stack_push(s, (struct node_base*) alloc_num(left->value + right->value));
}
void f_main(struct stack* s) {
// PushInt 320
stack_push(s, (struct node_base*) alloc_num(320));
// PushInt 6
stack_push(s, (struct node_base*) alloc_num(6));
// PushGlobal f_add (the function for +)
stack_push(s, (struct node_base*) alloc_global(f_add, 2));
struct node_base* left;
struct node_base* right;
// MkApp
left = stack_pop(s);
right = stack_pop(s);
stack_push(s, (struct node_base*) alloc_app(left, right));
// MkApp
left = stack_pop(s);
right = stack_pop(s);
stack_push(s, (struct node_base*) alloc_app(left, right));
}

View File

@@ -0,0 +1,2 @@
defn main = { sum 320 6 }
defn sum x y = { x + y }

View File

@@ -0,0 +1,3 @@
defn add x y = { x + y }
defn double x = { add x x }
defn main = { double 163 }

View File

@@ -0,0 +1,9 @@
data List a = { Nil, Cons a (List a) }
data Bool = { True, False }
defn length l = {
case l of {
Nil -> { 0 }
Cons x xs -> { 1 + length xs }
}
}
defn main = { length (Cons 1 (Cons 2 (Cons 3 Nil))) + length (Cons True (Cons False (Cons True Nil))) }

View File

@@ -0,0 +1,16 @@
data List = { Nil, Cons Int List }
defn add x y = { x + y }
defn mul x y = { x * y }
defn foldr f b l = {
case l of {
Nil -> { b }
Cons x xs -> { f x (foldr f b xs) }
}
}
defn main = {
foldr add 0 (Cons 1 (Cons 2 (Cons 3 (Cons 4 Nil)))) +
foldr mul 1 (Cons 1 (Cons 2 (Cons 3 (Cons 4 Nil))))
}

View File

@@ -0,0 +1,17 @@
data List = { Nil, Cons Int List }
defn sumZip l m = {
case l of {
Nil -> { 0 }
Cons x xs -> {
case m of {
Nil -> { 0 }
Cons y ys -> { x + y + sumZip xs ys }
}
}
}
}
defn ones = { Cons 1 ones }
defn main = { sumZip ones (Cons 1 (Cons 2 (Cons 3 Nil))) }

View File

@@ -0,0 +1,76 @@
#include "global_scope.hpp"
#include "ast.hpp"
void global_function::compile() {
env_ptr new_env = env_ptr(new env_offset(0, nullptr));
for(auto it = params.rbegin(); it != params.rend(); it++) {
new_env = env_ptr(new env_var(*it, new_env));
}
body->compile(new_env, instructions);
instructions.push_back(instruction_ptr(new instruction_update(params.size())));
instructions.push_back(instruction_ptr(new instruction_pop(params.size())));
}
void global_function::declare_llvm(llvm_context& ctx) {
generated_function = ctx.create_custom_function(name, params.size());
}
void global_function::generate_llvm(llvm_context& ctx) {
ctx.get_builder().SetInsertPoint(&generated_function->getEntryBlock());
for(auto& instruction : instructions) {
instruction->gen_llvm(ctx, generated_function);
}
ctx.get_builder().CreateRetVoid();
}
void global_constructor::generate_llvm(llvm_context& ctx) {
auto new_function =
ctx.create_custom_function(name, arity);
std::vector<instruction_ptr> instructions;
instructions.push_back(instruction_ptr(new instruction_pack(tag, arity)));
instructions.push_back(instruction_ptr(new instruction_update(0)));
ctx.get_builder().SetInsertPoint(&new_function->getEntryBlock());
for (auto& instruction : instructions) {
instruction->gen_llvm(ctx, new_function);
}
ctx.get_builder().CreateRetVoid();
}
global_function& global_scope::add_function(
const std::string& n,
std::vector<std::string> ps,
ast_ptr b) {
auto name = mng->new_mangled_name(n);
global_function* new_function =
new global_function(std::move(name), std::move(ps), std::move(b));
functions.push_back(global_function_ptr(new_function));
return *new_function;
}
global_constructor& global_scope::add_constructor(
const std::string& n,
int8_t t,
size_t a) {
auto name = mng->new_mangled_name(n);
global_constructor* new_constructor = new global_constructor(name, t, a);
constructors.push_back(global_constructor_ptr(new_constructor));
return *new_constructor;
}
void global_scope::compile() {
for(auto& function : functions) {
function->compile();
}
}
void global_scope::generate_llvm(llvm_context& ctx) {
for(auto& constructor : constructors) {
constructor->generate_llvm(ctx);
}
for(auto& function : functions) {
function->declare_llvm(ctx);
}
for(auto& function : functions) {
function->generate_llvm(ctx);
}
}

View File

@@ -0,0 +1,60 @@
#pragma once
#include <memory>
#include <string>
#include <vector>
#include <llvm/IR/Function.h>
#include "instruction.hpp"
#include "mangler.hpp"
struct ast;
using ast_ptr = std::unique_ptr<ast>;
struct global_function {
std::string name;
std::vector<std::string> params;
ast_ptr body;
std::vector<instruction_ptr> instructions;
llvm::Function* generated_function;
global_function(std::string n, std::vector<std::string> ps, ast_ptr b)
: name(std::move(n)), params(std::move(ps)), body(std::move(b)) {}
void compile();
void declare_llvm(llvm_context& ctx);
void generate_llvm(llvm_context& ctx);
};
using global_function_ptr = std::unique_ptr<global_function>;
struct global_constructor {
std::string name;
int8_t tag;
size_t arity;
global_constructor(std::string n, int8_t t, size_t a)
: name(std::move(n)), tag(t), arity(a) {}
void generate_llvm(llvm_context& ctx);
};
using global_constructor_ptr = std::unique_ptr<global_constructor>;
class global_scope {
private:
std::vector<global_function_ptr> functions;
std::vector<global_constructor_ptr> constructors;
mangler* mng;
public:
global_scope(mangler& m) : mng(&m) {}
global_function& add_function(
const std::string& n,
std::vector<std::string> ps,
ast_ptr b);
global_constructor& add_constructor(const std::string& n, int8_t t, size_t a);
void compile();
void generate_llvm(llvm_context& ctx);
};

114
code/compiler/13/graph.cpp Normal file
View File

@@ -0,0 +1,114 @@
#include "graph.hpp"
std::set<function_graph::edge> function_graph::compute_transitive_edges() {
std::set<edge> transitive_edges;
transitive_edges.insert(edges.begin(), edges.end());
for(auto& connector : adjacency_lists) {
for(auto& from : adjacency_lists) {
edge to_connector { from.first, connector.first };
for(auto& to : adjacency_lists) {
edge full_jump { from.first, to.first };
if(transitive_edges.find(full_jump) != transitive_edges.end()) continue;
edge from_connector { connector.first, to.first };
if(transitive_edges.find(to_connector) != transitive_edges.end() &&
transitive_edges.find(from_connector) != transitive_edges.end())
transitive_edges.insert(std::move(full_jump));
}
}
}
return transitive_edges;
}
void function_graph::create_groups(
const std::set<edge>& transitive_edges,
std::map<function, group_id>& group_ids,
std::map<group_id, data_ptr>& group_data_map) {
group_id id_counter = 0;
for(auto& vertex : adjacency_lists) {
if(group_ids.find(vertex.first) != group_ids.end())
continue;
data_ptr new_group(new group_data);
new_group->functions.insert(vertex.first);
group_data_map[id_counter] = new_group;
group_ids[vertex.first] = id_counter;
for(auto& other_vertex : adjacency_lists) {
if(transitive_edges.find({vertex.first, other_vertex.first}) != transitive_edges.end() &&
transitive_edges.find({other_vertex.first, vertex.first}) != transitive_edges.end()) {
group_ids[other_vertex.first] = id_counter;
new_group->functions.insert(other_vertex.first);
}
}
id_counter++;
}
}
void function_graph::create_edges(
std::map<function, group_id>& group_ids,
std::map<group_id, data_ptr>& group_data_map) {
std::set<std::pair<group_id, group_id>> group_edges;
for(auto& vertex : adjacency_lists) {
auto vertex_id = group_ids[vertex.first];
auto& vertex_data = group_data_map[vertex_id];
for(auto& other_vertex : vertex.second) {
auto other_id = group_ids[other_vertex];
if(vertex_id == other_id) continue;
if(group_edges.find({vertex_id, other_id}) != group_edges.end())
continue;
group_edges.insert({vertex_id, other_id});
vertex_data->adjacency_list.insert(other_id);
group_data_map[other_id]->indegree++;
}
}
}
std::vector<group_ptr> function_graph::generate_order(
std::map<function, group_id>& group_ids,
std::map<group_id, data_ptr>& group_data_map) {
std::queue<group_id> id_queue;
std::vector<group_ptr> output;
for(auto& group : group_data_map) {
if(group.second->indegree == 0) id_queue.push(group.first);
}
while(!id_queue.empty()) {
auto new_id = id_queue.front();
auto& group_data = group_data_map[new_id];
group_ptr output_group(new group);
output_group->members = std::move(group_data->functions);
id_queue.pop();
for(auto& adjacent_group : group_data->adjacency_list) {
if(--group_data_map[adjacent_group]->indegree == 0)
id_queue.push(adjacent_group);
}
output.push_back(std::move(output_group));
}
return output;
}
std::set<function>& function_graph::add_function(const function& f) {
auto adjacency_list_it = adjacency_lists.find(f);
if(adjacency_list_it != adjacency_lists.end()) {
return adjacency_list_it->second;
} else {
return adjacency_lists[f] = { };
}
}
void function_graph::add_edge(const function& from, const function& to) {
add_function(from).insert(to);
edges.insert({ from, to });
}
std::vector<group_ptr> function_graph::compute_order() {
std::set<edge> transitive_edges = compute_transitive_edges();
std::map<function, group_id> group_ids;
std::map<group_id, data_ptr> group_data_map;
create_groups(transitive_edges, group_ids, group_data_map);
create_edges(group_ids, group_data_map);
return generate_order(group_ids, group_data_map);
}

View File

@@ -0,0 +1,54 @@
#pragma once
#include <algorithm>
#include <cstddef>
#include <queue>
#include <set>
#include <string>
#include <map>
#include <memory>
#include <vector>
using function = std::string;
struct group {
std::set<function> members;
};
using group_ptr = std::unique_ptr<group>;
class function_graph {
private:
using group_id = size_t;
struct group_data {
std::set<function> functions;
std::set<group_id> adjacency_list;
size_t indegree;
group_data() : indegree(0) {}
};
using data_ptr = std::shared_ptr<group_data>;
using edge = std::pair<function, function>;
using group_edge = std::pair<group_id, group_id>;
std::map<function, std::set<function>> adjacency_lists;
std::set<edge> edges;
std::set<edge> compute_transitive_edges();
void create_groups(
const std::set<edge>&,
std::map<function, group_id>&,
std::map<group_id, data_ptr>&);
void create_edges(
std::map<function, group_id>&,
std::map<group_id, data_ptr>&);
std::vector<group_ptr> generate_order(
std::map<function, group_id>&,
std::map<group_id, data_ptr>&);
public:
std::set<function>& add_function(const function& f);
void add_edge(const function& from, const function& to);
std::vector<group_ptr> compute_order();
};

View File

@@ -0,0 +1,177 @@
#include "instruction.hpp"
#include "llvm_context.hpp"
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Function.h>
using namespace llvm;
static void print_indent(int n, std::ostream& to) {
while(n--) to << " ";
}
void instruction_pushint::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "PushInt(" << value << ")" << std::endl;
}
void instruction_pushint::gen_llvm(llvm_context& ctx, Function* f) const {
ctx.create_push(f, ctx.create_num(f, ctx.create_i32(value)));
}
void instruction_pushglobal::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "PushGlobal(" << name << ")" << std::endl;
}
void instruction_pushglobal::gen_llvm(llvm_context& ctx, Function* f) const {
auto& global_f = ctx.get_custom_function(name);
auto arity = ctx.create_i32(global_f.arity);
ctx.create_push(f, ctx.create_global(f, global_f.function, arity));
}
void instruction_push::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "Push(" << offset << ")" << std::endl;
}
void instruction_push::gen_llvm(llvm_context& ctx, Function* f) const {
ctx.create_push(f, ctx.create_peek(f, ctx.create_size(offset)));
}
void instruction_pop::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "Pop(" << count << ")" << std::endl;
}
void instruction_pop::gen_llvm(llvm_context& ctx, Function* f) const {
ctx.create_popn(f, ctx.create_size(count));
}
void instruction_mkapp::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "MkApp()" << std::endl;
}
void instruction_mkapp::gen_llvm(llvm_context& ctx, Function* f) const {
auto left = ctx.create_pop(f);
auto right = ctx.create_pop(f);
ctx.create_push(f, ctx.create_app(f, left, right));
}
void instruction_update::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "Update(" << offset << ")" << std::endl;
}
void instruction_update::gen_llvm(llvm_context& ctx, Function* f) const {
ctx.create_update(f, ctx.create_size(offset));
}
void instruction_pack::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "Pack(" << tag << ", " << size << ")" << std::endl;
}
void instruction_pack::gen_llvm(llvm_context& ctx, Function* f) const {
ctx.create_pack(f, ctx.create_size(size), ctx.create_i8(tag));
}
void instruction_split::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "Split()" << std::endl;
}
void instruction_split::gen_llvm(llvm_context& ctx, Function* f) const {
ctx.create_split(f, ctx.create_size(size));
}
void instruction_jump::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "Jump(" << std::endl;
for(auto& instruction_set : branches) {
for(auto& instruction : instruction_set) {
instruction->print(indent + 2, to);
}
to << std::endl;
}
print_indent(indent, to);
to << ")" << std::endl;
}
void instruction_jump::gen_llvm(llvm_context& ctx, Function* f) const {
auto top_node = ctx.create_peek(f, ctx.create_size(0));
auto tag = ctx.unwrap_data_tag(top_node);
auto safety_block = ctx.create_basic_block("safety", f);
auto switch_op = ctx.get_builder().CreateSwitch(tag, safety_block, tag_mappings.size());
std::vector<BasicBlock*> blocks;
for(auto& branch : branches) {
auto branch_block = ctx.create_basic_block("branch", f);
ctx.get_builder().SetInsertPoint(branch_block);
for(auto& instruction : branch) {
instruction->gen_llvm(ctx, f);
}
ctx.get_builder().CreateBr(safety_block);
blocks.push_back(branch_block);
}
for(auto& mapping : tag_mappings) {
switch_op->addCase(ctx.create_i8(mapping.first), blocks[mapping.second]);
}
ctx.get_builder().SetInsertPoint(safety_block);
}
void instruction_slide::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "Slide(" << offset << ")" << std::endl;
}
void instruction_slide::gen_llvm(llvm_context& ctx, Function* f) const {
ctx.create_slide(f, ctx.create_size(offset));
}
void instruction_binop::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "BinOp(" << op_action(op) << ")" << std::endl;
}
void instruction_binop::gen_llvm(llvm_context& ctx, Function* f) const {
auto left_int = ctx.unwrap_num(ctx.create_pop(f));
auto right_int = ctx.unwrap_num(ctx.create_pop(f));
llvm::Value* result;
switch(op) {
case PLUS: result = ctx.get_builder().CreateAdd(left_int, right_int); break;
case MINUS: result = ctx.get_builder().CreateSub(left_int, right_int); break;
case TIMES: result = ctx.get_builder().CreateMul(left_int, right_int); break;
case DIVIDE: result = ctx.get_builder().CreateSDiv(left_int, right_int); break;
}
ctx.create_push(f, ctx.create_num(f, result));
}
void instruction_eval::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "Eval()" << std::endl;
}
void instruction_eval::gen_llvm(llvm_context& ctx, Function* f) const {
ctx.create_unwind(f);
}
void instruction_alloc::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "Alloc(" << amount << ")" << std::endl;
}
void instruction_alloc::gen_llvm(llvm_context& ctx, Function* f) const {
ctx.create_alloc(f, ctx.create_size(amount));
}
void instruction_unwind::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "Unwind()" << std::endl;
}
void instruction_unwind::gen_llvm(llvm_context& ctx, Function* f) const {
// Nothing
}

View File

@@ -0,0 +1,142 @@
#pragma once
#include <llvm/IR/Function.h>
#include <string>
#include <memory>
#include <vector>
#include <map>
#include <ostream>
#include "binop.hpp"
#include "llvm_context.hpp"
struct instruction {
virtual ~instruction() = default;
virtual void print(int indent, std::ostream& to) const = 0;
virtual void gen_llvm(llvm_context& ctx, llvm::Function* f) const = 0;
};
using instruction_ptr = std::unique_ptr<instruction>;
struct instruction_pushint : public instruction {
int value;
instruction_pushint(int v)
: value(v) {}
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_pushglobal : public instruction {
std::string name;
instruction_pushglobal(std::string n)
: name(std::move(n)) {}
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_push : public instruction {
int offset;
instruction_push(int o)
: offset(o) {}
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_pop : public instruction {
int count;
instruction_pop(int c)
: count(c) {}
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_mkapp : public instruction {
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_update : public instruction {
int offset;
instruction_update(int o)
: offset(o) {}
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_pack : public instruction {
int tag;
int size;
instruction_pack(int t, int s)
: tag(t), size(s) {}
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_split : public instruction {
int size;
instruction_split(int s)
: size(s) {}
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_jump : public instruction {
std::vector<std::vector<instruction_ptr>> branches;
std::map<int, int> tag_mappings;
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_slide : public instruction {
int offset;
instruction_slide(int o)
: offset(o) {}
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_binop : public instruction {
binop op;
instruction_binop(binop o)
: op(o) {}
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_eval : public instruction {
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_alloc : public instruction {
int amount;
instruction_alloc(int a)
: amount(a) {}
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_unwind : public instruction {
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};

View File

@@ -0,0 +1,294 @@
#include "llvm_context.hpp"
#include <llvm/IR/DerivedTypes.h>
using namespace llvm;
void llvm_context::create_types() {
stack_type = StructType::create(ctx, "stack");
gmachine_type = StructType::create(ctx, "gmachine");
stack_ptr_type = PointerType::getUnqual(stack_type);
gmachine_ptr_type = PointerType::getUnqual(gmachine_type);
tag_type = IntegerType::getInt8Ty(ctx);
struct_types["node_base"] = StructType::create(ctx, "node_base");
struct_types["node_app"] = StructType::create(ctx, "node_app");
struct_types["node_num"] = StructType::create(ctx, "node_num");
struct_types["node_global"] = StructType::create(ctx, "node_global");
struct_types["node_ind"] = StructType::create(ctx, "node_ind");
struct_types["node_data"] = StructType::create(ctx, "node_data");
node_ptr_type = PointerType::getUnqual(struct_types.at("node_base"));
function_type = FunctionType::get(Type::getVoidTy(ctx), { gmachine_ptr_type }, false);
gmachine_type->setBody(
stack_ptr_type,
node_ptr_type,
IntegerType::getInt64Ty(ctx),
IntegerType::getInt64Ty(ctx)
);
struct_types.at("node_base")->setBody(
IntegerType::getInt32Ty(ctx),
IntegerType::getInt8Ty(ctx),
node_ptr_type
);
struct_types.at("node_app")->setBody(
struct_types.at("node_base"),
node_ptr_type,
node_ptr_type
);
struct_types.at("node_num")->setBody(
struct_types.at("node_base"),
IntegerType::getInt32Ty(ctx)
);
struct_types.at("node_global")->setBody(
struct_types.at("node_base"),
FunctionType::get(Type::getVoidTy(ctx), { stack_ptr_type }, false)
);
struct_types.at("node_ind")->setBody(
struct_types.at("node_base"),
node_ptr_type
);
struct_types.at("node_data")->setBody(
struct_types.at("node_base"),
IntegerType::getInt8Ty(ctx),
PointerType::getUnqual(node_ptr_type)
);
}
void llvm_context::create_functions() {
auto void_type = Type::getVoidTy(ctx);
auto sizet_type = IntegerType::get(ctx, sizeof(size_t) * 8);
functions["stack_init"] = Function::Create(
FunctionType::get(void_type, { stack_ptr_type }, false),
Function::LinkageTypes::ExternalLinkage,
"stack_init",
&module
);
functions["stack_free"] = Function::Create(
FunctionType::get(void_type, { stack_ptr_type }, false),
Function::LinkageTypes::ExternalLinkage,
"stack_free",
&module
);
functions["stack_push"] = Function::Create(
FunctionType::get(void_type, { stack_ptr_type, node_ptr_type }, false),
Function::LinkageTypes::ExternalLinkage,
"stack_push",
&module
);
functions["stack_pop"] = Function::Create(
FunctionType::get(node_ptr_type, { stack_ptr_type }, false),
Function::LinkageTypes::ExternalLinkage,
"stack_pop",
&module
);
functions["stack_peek"] = Function::Create(
FunctionType::get(node_ptr_type, { stack_ptr_type, sizet_type }, false),
Function::LinkageTypes::ExternalLinkage,
"stack_peek",
&module
);
functions["stack_popn"] = Function::Create(
FunctionType::get(void_type, { stack_ptr_type, sizet_type }, false),
Function::LinkageTypes::ExternalLinkage,
"stack_popn",
&module
);
functions["gmachine_slide"] = Function::Create(
FunctionType::get(void_type, { gmachine_ptr_type, sizet_type }, false),
Function::LinkageTypes::ExternalLinkage,
"gmachine_slide",
&module
);
functions["gmachine_update"] = Function::Create(
FunctionType::get(void_type, { gmachine_ptr_type, sizet_type }, false),
Function::LinkageTypes::ExternalLinkage,
"gmachine_update",
&module
);
functions["gmachine_alloc"] = Function::Create(
FunctionType::get(void_type, { gmachine_ptr_type, sizet_type }, false),
Function::LinkageTypes::ExternalLinkage,
"gmachine_alloc",
&module
);
functions["gmachine_pack"] = Function::Create(
FunctionType::get(void_type, { gmachine_ptr_type, sizet_type, tag_type }, false),
Function::LinkageTypes::ExternalLinkage,
"gmachine_pack",
&module
);
functions["gmachine_split"] = Function::Create(
FunctionType::get(void_type, { gmachine_ptr_type, sizet_type }, false),
Function::LinkageTypes::ExternalLinkage,
"gmachine_split",
&module
);
functions["gmachine_track"] = Function::Create(
FunctionType::get(node_ptr_type, { gmachine_ptr_type, node_ptr_type }, false),
Function::LinkageTypes::ExternalLinkage,
"gmachine_track",
&module
);
auto int32_type = IntegerType::getInt32Ty(ctx);
functions["alloc_app"] = Function::Create(
FunctionType::get(node_ptr_type, { node_ptr_type, node_ptr_type }, false),
Function::LinkageTypes::ExternalLinkage,
"alloc_app",
&module
);
functions["alloc_num"] = Function::Create(
FunctionType::get(node_ptr_type, { int32_type }, false),
Function::LinkageTypes::ExternalLinkage,
"alloc_num",
&module
);
functions["alloc_global"] = Function::Create(
FunctionType::get(node_ptr_type, { function_type, int32_type }, false),
Function::LinkageTypes::ExternalLinkage,
"alloc_global",
&module
);
functions["alloc_ind"] = Function::Create(
FunctionType::get(node_ptr_type, { node_ptr_type }, false),
Function::LinkageTypes::ExternalLinkage,
"alloc_ind",
&module
);
functions["unwind"] = Function::Create(
FunctionType::get(void_type, { gmachine_ptr_type }, false),
Function::LinkageTypes::ExternalLinkage,
"unwind",
&module
);
}
IRBuilder<>& llvm_context::get_builder() {
return builder;
}
Module& llvm_context::get_module() {
return module;
}
BasicBlock* llvm_context::create_basic_block(const std::string& name, llvm::Function* f) {
return BasicBlock::Create(ctx, name, f);
}
ConstantInt* llvm_context::create_i8(int8_t i) {
return ConstantInt::get(ctx, APInt(8, i));
}
ConstantInt* llvm_context::create_i32(int32_t i) {
return ConstantInt::get(ctx, APInt(32, i));
}
ConstantInt* llvm_context::create_size(size_t i) {
return ConstantInt::get(ctx, APInt(sizeof(size_t) * 8, i));
}
Value* llvm_context::create_pop(Function* f) {
auto pop_f = functions.at("stack_pop");
return builder.CreateCall(pop_f, { unwrap_gmachine_stack_ptr(f->arg_begin()) });
}
Value* llvm_context::create_peek(Function* f, Value* off) {
auto peek_f = functions.at("stack_peek");
return builder.CreateCall(peek_f, { unwrap_gmachine_stack_ptr(f->arg_begin()), off });
}
void llvm_context::create_push(Function* f, Value* v) {
auto push_f = functions.at("stack_push");
builder.CreateCall(push_f, { unwrap_gmachine_stack_ptr(f->arg_begin()), v });
}
void llvm_context::create_popn(Function* f, Value* off) {
auto popn_f = functions.at("stack_popn");
builder.CreateCall(popn_f, { unwrap_gmachine_stack_ptr(f->arg_begin()), off });
}
void llvm_context::create_update(Function* f, Value* off) {
auto update_f = functions.at("gmachine_update");
builder.CreateCall(update_f, { f->arg_begin(), off });
}
void llvm_context::create_pack(Function* f, Value* c, Value* t) {
auto pack_f = functions.at("gmachine_pack");
builder.CreateCall(pack_f, { f->arg_begin(), c, t });
}
void llvm_context::create_split(Function* f, Value* c) {
auto split_f = functions.at("gmachine_split");
builder.CreateCall(split_f, { f->arg_begin(), c });
}
void llvm_context::create_slide(Function* f, Value* off) {
auto slide_f = functions.at("gmachine_slide");
builder.CreateCall(slide_f, { f->arg_begin(), off });
}
void llvm_context::create_alloc(Function* f, Value* n) {
auto alloc_f = functions.at("gmachine_alloc");
builder.CreateCall(alloc_f, { f->arg_begin(), n });
}
Value* llvm_context::create_track(Function* f, Value* v) {
auto track_f = functions.at("gmachine_track");
return builder.CreateCall(track_f, { f->arg_begin(), v });
}
void llvm_context::create_unwind(Function* f) {
auto unwind_f = functions.at("unwind");
builder.CreateCall(unwind_f, { f->args().begin() });
}
Value* llvm_context::unwrap_gmachine_stack_ptr(Value* g) {
auto offset_0 = create_i32(0);
return builder.CreateGEP(g, { offset_0, offset_0 });
}
Value* llvm_context::unwrap_num(Value* v) {
auto num_ptr_type = PointerType::getUnqual(struct_types.at("node_num"));
auto cast = builder.CreatePointerCast(v, num_ptr_type);
auto offset_0 = create_i32(0);
auto offset_1 = create_i32(1);
auto int_ptr = builder.CreateGEP(cast, { offset_0, offset_1 });
return builder.CreateLoad(int_ptr);
}
Value* llvm_context::create_num(Function* f, Value* v) {
auto alloc_num_f = functions.at("alloc_num");
auto alloc_num_call = builder.CreateCall(alloc_num_f, { v });
return create_track(f, alloc_num_call);
}
Value* llvm_context::unwrap_data_tag(Value* v) {
auto data_ptr_type = PointerType::getUnqual(struct_types.at("node_data"));
auto cast = builder.CreatePointerCast(v, data_ptr_type);
auto offset_0 = create_i32(0);
auto offset_1 = create_i32(1);
auto tag_ptr = builder.CreateGEP(cast, { offset_0, offset_1 });
return builder.CreateLoad(tag_ptr);
}
Value* llvm_context::create_global(Function* f, Value* gf, Value* a) {
auto alloc_global_f = functions.at("alloc_global");
auto alloc_global_call = builder.CreateCall(alloc_global_f, { gf, a });
return create_track(f, alloc_global_call);
}
Value* llvm_context::create_app(Function* f, Value* l, Value* r) {
auto alloc_app_f = functions.at("alloc_app");
auto alloc_app_call = builder.CreateCall(alloc_app_f, { l, r });
return create_track(f, alloc_app_call);
}
llvm::Function* llvm_context::create_custom_function(const std::string& name, int32_t arity) {
auto void_type = llvm::Type::getVoidTy(ctx);
auto new_function = llvm::Function::Create(
function_type,
llvm::Function::LinkageTypes::ExternalLinkage,
"f_" + name,
&module
);
auto start_block = llvm::BasicBlock::Create(ctx, "entry", new_function);
auto new_custom_f = custom_function_ptr(new custom_function());
new_custom_f->arity = arity;
new_custom_f->function = new_function;
custom_functions["f_" + name] = std::move(new_custom_f);
return new_function;
}
llvm_context::custom_function& llvm_context::get_custom_function(const std::string& name) {
return *custom_functions.at("f_" + name);
}

View File

@@ -0,0 +1,81 @@
#pragma once
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Value.h>
#include <map>
class llvm_context {
public:
struct custom_function {
llvm::Function* function;
int32_t arity;
};
using custom_function_ptr = std::unique_ptr<custom_function>;
private:
llvm::LLVMContext ctx;
llvm::IRBuilder<> builder;
llvm::Module module;
std::map<std::string, custom_function_ptr> custom_functions;
std::map<std::string, llvm::Function*> functions;
std::map<std::string, llvm::StructType*> struct_types;
llvm::StructType* stack_type;
llvm::StructType* gmachine_type;
llvm::PointerType* stack_ptr_type;
llvm::PointerType* gmachine_ptr_type;
llvm::PointerType* node_ptr_type;
llvm::IntegerType* tag_type;
llvm::FunctionType* function_type;
void create_types();
void create_functions();
public:
llvm_context()
: builder(ctx), module("bloglang", ctx) {
create_types();
create_functions();
}
llvm::IRBuilder<>& get_builder();
llvm::Module& get_module();
llvm::BasicBlock* create_basic_block(const std::string& name, llvm::Function* f);
llvm::ConstantInt* create_i8(int8_t);
llvm::ConstantInt* create_i32(int32_t);
llvm::ConstantInt* create_size(size_t);
llvm::Value* create_pop(llvm::Function*);
llvm::Value* create_peek(llvm::Function*, llvm::Value*);
void create_push(llvm::Function*, llvm::Value*);
void create_popn(llvm::Function*, llvm::Value*);
void create_update(llvm::Function*, llvm::Value*);
void create_pack(llvm::Function*, llvm::Value*, llvm::Value*);
void create_split(llvm::Function*, llvm::Value*);
void create_slide(llvm::Function*, llvm::Value*);
void create_alloc(llvm::Function*, llvm::Value*);
llvm::Value* create_track(llvm::Function*, llvm::Value*);
void create_unwind(llvm::Function*);
llvm::Value* unwrap_gmachine_stack_ptr(llvm::Value*);
llvm::Value* unwrap_num(llvm::Value*);
llvm::Value* create_num(llvm::Function*, llvm::Value*);
llvm::Value* unwrap_data_tag(llvm::Value*);
llvm::Value* create_global(llvm::Function*, llvm::Value*, llvm::Value*);
llvm::Value* create_app(llvm::Function*, llvm::Value*, llvm::Value*);
llvm::Function* create_custom_function(const std::string& name, int32_t arity);
custom_function& get_custom_function(const std::string& name);
};

27
code/compiler/13/main.cpp Normal file
View File

@@ -0,0 +1,27 @@
#include "ast.hpp"
#include <iostream>
#include "parser.hpp"
#include "compiler.hpp"
#include "error.hpp"
void yy::parser::error(const yy::location& loc, const std::string& msg) {
std::cerr << "An error occured: " << msg << std::endl;
}
int main(int argc, char** argv) {
if(argc != 2) {
std::cerr << "please enter a file to compile." << std::endl;
exit(1);
}
compiler cmp(argv[1]);
try {
cmp("program.o");
} catch(unification_error& err) {
err.pretty_print(std::cerr, cmp.get_file_manager(), cmp.get_type_manager());
} catch(type_error& err) {
err.pretty_print(std::cerr, cmp.get_file_manager());
} catch (compiler_error& err) {
err.pretty_print(std::cerr, cmp.get_file_manager());
}
}

View File

@@ -0,0 +1,17 @@
#include "mangler.hpp"
std::string mangler::new_mangled_name(const std::string& n) {
auto occurence_it = occurence_count.find(n);
int occurence = 0;
if(occurence_it != occurence_count.end()) {
occurence = occurence_it->second + 1;
}
occurence_count[n] = occurence;
std::string final_name = n;
if (occurence != 0) {
final_name += "_";
final_name += std::to_string(occurence);
}
return final_name;
}

View File

@@ -0,0 +1,11 @@
#pragma once
#include <string>
#include <map>
class mangler {
private:
std::map<std::string, int> occurence_count;
public:
std::string new_mangled_name(const std::string& str);
};

View File

@@ -0,0 +1,72 @@
#include "parse_driver.hpp"
#include "scanner.hpp"
#include <sstream>
file_mgr::file_mgr() : file_offset(0) {
line_offsets.push_back(0);
}
void file_mgr::write(const char* buf, size_t len) {
string_stream.write(buf, len);
file_offset += len;
}
void file_mgr::mark_line() {
line_offsets.push_back(file_offset);
}
void file_mgr::finalize() {
file_contents = string_stream.str();
}
size_t file_mgr::get_index(int line, int column) const {
assert(line > 0 && line <= line_offsets.size());
return line_offsets.at(line-1) + column - 1;
}
size_t file_mgr::get_line_end(int line) const {
if(line == line_offsets.size()) return file_contents.size();
return get_index(line+1, 1);
}
void file_mgr::print_location(
std::ostream& stream,
const yy::location& loc,
bool highlight) const {
size_t print_start = get_index(loc.begin.line, 1);
size_t highlight_start = get_index(loc.begin.line, loc.begin.column);
size_t highlight_end = get_index(loc.end.line, loc.end.column);
size_t print_end = get_line_end(loc.end.line);
const char* content = file_contents.c_str();
stream.write(content + print_start, highlight_start - print_start);
if(highlight) stream << "\033[4;31m";
stream.write(content + highlight_start, highlight_end - highlight_start);
if(highlight) stream << "\033[0m";
stream.write(content + highlight_end, print_end - highlight_end);
}
bool parse_driver::operator()() {
FILE* stream = fopen(file_name.c_str(), "r");
if(!stream) return false;
yyscan_t scanner;
yylex_init(&scanner);
yyset_in(stream, scanner);
yy::parser parser(scanner, *this);
parser();
yylex_destroy(scanner);
fclose(stream);
file_m->finalize();
return true;
}
yy::location& parse_driver::get_current_location() {
return location;
}
file_mgr& parse_driver::get_file_manager() const {
return *file_m;
}
definition_group& parse_driver::get_global_defs() const {
return *global_defs;
}

View File

@@ -0,0 +1,58 @@
#pragma once
#include <string>
#include <fstream>
#include <sstream>
#include "definition.hpp"
#include "location.hh"
#include "parser.hpp"
struct parse_driver;
void scanner_init(parse_driver* d, yyscan_t* scanner);
void scanner_destroy(yyscan_t* scanner);
class file_mgr {
private:
std::ostringstream string_stream;
std::string file_contents;
size_t file_offset;
std::vector<size_t> line_offsets;
public:
file_mgr();
void write(const char* buffer, size_t len);
void mark_line();
void finalize();
size_t get_index(int line, int column) const;
size_t get_line_end(int line) const;
void print_location(
std::ostream& stream,
const yy::location& loc,
bool highlight = true) const;
};
class parse_driver {
private:
std::string file_name;
yy::location location;
definition_group* global_defs;
file_mgr* file_m;
public:
parse_driver(
file_mgr& mgr,
definition_group& defs,
const std::string& file)
: file_name(file), file_m(&mgr), global_defs(&defs) {}
bool operator()();
yy::location& get_current_location();
file_mgr& get_file_manager() const;
definition_group& get_global_defs() const;
};
#define YY_DECL yy::parser::symbol_type yylex(yyscan_t yyscanner, parse_driver& drv)
YY_DECL;

View File

@@ -0,0 +1,48 @@
#include "parsed_type.hpp"
#include <sstream>
#include "type.hpp"
#include "type_env.hpp"
#include "error.hpp"
type_ptr parsed_type_app::to_type(
const std::set<std::string>& vars,
const type_env& e) const {
auto parent_type = e.lookup_type(name);
if(parent_type == nullptr)
throw type_error("no such type or type constructor " + name);
type_base* base_type;
if(!(base_type = dynamic_cast<type_base*>(parent_type.get())))
throw type_error("invalid type " + name);
if(base_type->arity != arguments.size()) {
std::ostringstream error_stream;
error_stream << "invalid application of type ";
error_stream << name;
error_stream << " (" << base_type->arity << " argument(s) expected, ";
error_stream << "but " << arguments.size() << " provided)";
throw type_error(error_stream.str());
}
type_app* new_app = new type_app(std::move(parent_type));
type_ptr to_return(new_app);
for(auto& arg : arguments) {
new_app->arguments.push_back(arg->to_type(vars, e));
}
return to_return;
}
type_ptr parsed_type_var::to_type(
const std::set<std::string>& vars,
const type_env& e) const {
if(vars.find(var) == vars.end())
throw type_error("the type variable " + var + " was not explicitly declared.");
return type_ptr(new type_var(var));
}
type_ptr parsed_type_arr::to_type(
const std::set<std::string>& vars,
const type_env& env) const {
auto new_left = left->to_type(vars, env);
auto new_right = right->to_type(vars, env);
return type_ptr(new type_arr(std::move(new_left), std::move(new_right)));
}

View File

@@ -0,0 +1,43 @@
#pragma once
#include <memory>
#include <set>
#include <string>
#include "type_env.hpp"
struct parsed_type {
virtual type_ptr to_type(
const std::set<std::string>& vars,
const type_env& env) const = 0;
};
using parsed_type_ptr = std::unique_ptr<parsed_type>;
struct parsed_type_app : parsed_type {
std::string name;
std::vector<parsed_type_ptr> arguments;
parsed_type_app(
std::string n,
std::vector<parsed_type_ptr> as)
: name(std::move(n)), arguments(std::move(as)) {}
type_ptr to_type(const std::set<std::string>& vars, const type_env& env) const;
};
struct parsed_type_var : parsed_type {
std::string var;
parsed_type_var(std::string v) : var(std::move(v)) {}
type_ptr to_type(const std::set<std::string>& vars, const type_env& env) const;
};
struct parsed_type_arr : parsed_type {
parsed_type_ptr left;
parsed_type_ptr right;
parsed_type_arr(parsed_type_ptr l, parsed_type_ptr r)
: left(std::move(l)), right(std::move(r)) {}
type_ptr to_type(const std::set<std::string>& vars, const type_env& env) const;
};

180
code/compiler/13/parser.y Normal file
View File

@@ -0,0 +1,180 @@
%code requires {
#include <string>
#include <vector>
#include "ast.hpp"
#include "definition.hpp"
#include "parser.hpp"
#include "parsed_type.hpp"
class parse_driver;
using yyscan_t = void*;
}
%param { yyscan_t scanner }
%param { parse_driver& drv }
%code {
#include "parse_driver.hpp"
}
%token BACKSLASH
%token PLUS
%token TIMES
%token MINUS
%token DIVIDE
%token <int> INT
%token DEFN
%token DATA
%token CASE
%token OF
%token LET
%token IN
%token OCURLY
%token CCURLY
%token OPAREN
%token CPAREN
%token COMMA
%token ARROW
%token EQUAL
%token <std::string> LID
%token <std::string> UID
%language "c++"
%define api.value.type variant
%define api.token.constructor
%locations
%type <std::vector<std::string>> lowercaseParams
%type <std::vector<branch_ptr>> branches
%type <std::vector<constructor_ptr>> constructors
%type <std::vector<parsed_type_ptr>> typeList
%type <definition_group> definitions
%type <parsed_type_ptr> type nonArrowType typeListElement
%type <ast_ptr> aAdd aMul case let lambda app appBase
%type <definition_data_ptr> data
%type <definition_defn_ptr> defn
%type <branch_ptr> branch
%type <pattern_ptr> pattern
%type <constructor_ptr> constructor
%start program
%%
program
: definitions { $1.vis = visibility::global; std::swap(drv.get_global_defs(), $1); }
;
definitions
: definitions defn { $$ = std::move($1); auto name = $2->name; $$.defs_defn[name] = std::move($2); }
| definitions data { $$ = std::move($1); auto name = $2->name; $$.defs_data[name] = std::move($2); }
| %empty { $$ = definition_group(); }
;
defn
: DEFN LID lowercaseParams EQUAL OCURLY aAdd CCURLY
{ $$ = definition_defn_ptr(
new definition_defn(std::move($2), std::move($3), std::move($6), @$)); }
;
lowercaseParams
: %empty { $$ = std::vector<std::string>(); }
| lowercaseParams LID { $$ = std::move($1); $$.push_back(std::move($2)); }
;
aAdd
: aAdd PLUS aMul { $$ = ast_ptr(new ast_binop(PLUS, std::move($1), std::move($3), @$)); }
| aAdd MINUS aMul { $$ = ast_ptr(new ast_binop(MINUS, std::move($1), std::move($3), @$)); }
| aMul { $$ = std::move($1); }
;
aMul
: aMul TIMES app { $$ = ast_ptr(new ast_binop(TIMES, std::move($1), std::move($3), @$)); }
| aMul DIVIDE app { $$ = ast_ptr(new ast_binop(DIVIDE, std::move($1), std::move($3), @$)); }
| app { $$ = std::move($1); }
;
app
: app appBase { $$ = ast_ptr(new ast_app(std::move($1), std::move($2), @$)); }
| appBase { $$ = std::move($1); }
;
appBase
: INT { $$ = ast_ptr(new ast_int($1, @$)); }
| LID { $$ = ast_ptr(new ast_lid(std::move($1), @$)); }
| UID { $$ = ast_ptr(new ast_uid(std::move($1), @$)); }
| OPAREN aAdd CPAREN { $$ = std::move($2); }
| case { $$ = std::move($1); }
| let { $$ = std::move($1); }
| lambda { $$ = std::move($1); }
;
let
: LET OCURLY definitions CCURLY IN OCURLY aAdd CCURLY
{ $$ = ast_ptr(new ast_let(std::move($3), std::move($7), @$)); }
;
lambda
: BACKSLASH lowercaseParams ARROW OCURLY aAdd CCURLY
{ $$ = ast_ptr(new ast_lambda(std::move($2), std::move($5), @$)); }
;
case
: CASE aAdd OF OCURLY branches CCURLY
{ $$ = ast_ptr(new ast_case(std::move($2), std::move($5), @$)); }
;
branches
: branches branch { $$ = std::move($1); $$.push_back(std::move($2)); }
| branch { $$ = std::vector<branch_ptr>(); $$.push_back(std::move($1));}
;
branch
: pattern ARROW OCURLY aAdd CCURLY
{ $$ = branch_ptr(new branch(std::move($1), std::move($4))); }
;
pattern
: LID { $$ = pattern_ptr(new pattern_var(std::move($1), @$)); }
| UID lowercaseParams
{ $$ = pattern_ptr(new pattern_constr(std::move($1), std::move($2), @$)); }
;
data
: DATA UID lowercaseParams EQUAL OCURLY constructors CCURLY
{ $$ = definition_data_ptr(new definition_data(std::move($2), std::move($3), std::move($6), @$)); }
;
constructors
: constructors COMMA constructor { $$ = std::move($1); $$.push_back(std::move($3)); }
| constructor
{ $$ = std::vector<constructor_ptr>(); $$.push_back(std::move($1)); }
;
constructor
: UID typeList
{ $$ = constructor_ptr(new constructor(std::move($1), std::move($2))); }
;
type
: nonArrowType ARROW type { $$ = parsed_type_ptr(new parsed_type_arr(std::move($1), std::move($3))); }
| nonArrowType { $$ = std::move($1); }
;
nonArrowType
: UID typeList { $$ = parsed_type_ptr(new parsed_type_app(std::move($1), std::move($2))); }
| LID { $$ = parsed_type_ptr(new parsed_type_var(std::move($1))); }
| OPAREN type CPAREN { $$ = std::move($2); }
;
typeListElement
: OPAREN type CPAREN { $$ = std::move($2); }
| UID { $$ = parsed_type_ptr(new parsed_type_app(std::move($1), {})); }
| LID { $$ = parsed_type_ptr(new parsed_type_var(std::move($1))); }
;
typeList
: %empty { $$ = std::vector<parsed_type_ptr>(); }
| typeList typeListElement { $$ = std::move($1); $$.push_back(std::move($2)); }
;

269
code/compiler/13/runtime.c Normal file
View File

@@ -0,0 +1,269 @@
#include <stdint.h>
#include <assert.h>
#include <memory.h>
#include <stdio.h>
#include "runtime.h"
struct node_base* alloc_node() {
struct node_base* new_node = malloc(sizeof(struct node_app));
new_node->gc_next = NULL;
new_node->gc_reachable = 0;
assert(new_node != NULL);
return new_node;
}
struct node_app* alloc_app(struct node_base* l, struct node_base* r) {
struct node_app* node = (struct node_app*) alloc_node();
node->base.tag = NODE_APP;
node->left = l;
node->right = r;
return node;
}
struct node_num* alloc_num(int32_t n) {
struct node_num* node = (struct node_num*) alloc_node();
node->base.tag = NODE_NUM;
node->value = n;
return node;
}
struct node_global* alloc_global(void (*f)(struct gmachine*), int32_t a) {
struct node_global* node = (struct node_global*) alloc_node();
node->base.tag = NODE_GLOBAL;
node->arity = a;
node->function = f;
return node;
}
struct node_ind* alloc_ind(struct node_base* n) {
struct node_ind* node = (struct node_ind*) alloc_node();
node->base.tag = NODE_IND;
node->next = n;
return node;
}
void free_node_direct(struct node_base* n) {
if(n->tag == NODE_DATA) {
free(((struct node_data*) n)->array);
}
}
void gc_visit_node(struct node_base* n) {
if(n->gc_reachable) return;
n->gc_reachable = 1;
if(n->tag == NODE_APP) {
struct node_app* app = (struct node_app*) n;
gc_visit_node(app->left);
gc_visit_node(app->right);
} if(n->tag == NODE_IND) {
struct node_ind* ind = (struct node_ind*) n;
gc_visit_node(ind->next);
} if(n->tag == NODE_DATA) {
struct node_data* data = (struct node_data*) n;
struct node_base** to_visit = data->array;
while(*to_visit) {
gc_visit_node(*to_visit);
to_visit++;
}
}
}
void stack_init(struct stack* s) {
s->size = 4;
s->count = 0;
s->data = malloc(sizeof(*s->data) * s->size);
assert(s->data != NULL);
}
void stack_free(struct stack* s) {
free(s->data);
}
void stack_push(struct stack* s, struct node_base* n) {
while(s->count >= s->size) {
s->data = realloc(s->data, sizeof(*s->data) * (s->size *= 2));
assert(s->data != NULL);
}
s->data[s->count++] = n;
}
struct node_base* stack_pop(struct stack* s) {
assert(s->count > 0);
return s->data[--s->count];
}
struct node_base* stack_peek(struct stack* s, size_t o) {
assert(s->count > o);
return s->data[s->count - o - 1];
}
void stack_popn(struct stack* s, size_t n) {
assert(s->count >= n);
s->count -= n;
}
void gmachine_init(struct gmachine* g) {
stack_init(&g->stack);
g->gc_nodes = NULL;
g->gc_node_count = 0;
g->gc_node_threshold = 128;
}
void gmachine_free(struct gmachine* g) {
stack_free(&g->stack);
struct node_base* to_free = g->gc_nodes;
struct node_base* next;
while(to_free) {
next = to_free->gc_next;
free_node_direct(to_free);
free(to_free);
to_free = next;
}
}
void gmachine_slide(struct gmachine* g, size_t n) {
assert(g->stack.count > n);
g->stack.data[g->stack.count - n - 1] = g->stack.data[g->stack.count - 1];
g->stack.count -= n;
}
void gmachine_update(struct gmachine* g, size_t o) {
assert(g->stack.count > o + 1);
struct node_ind* ind =
(struct node_ind*) g->stack.data[g->stack.count - o - 2];
ind->base.tag = NODE_IND;
ind->next = g->stack.data[g->stack.count -= 1];
}
void gmachine_alloc(struct gmachine* g, size_t o) {
while(o--) {
stack_push(&g->stack,
gmachine_track(g, (struct node_base*) alloc_ind(NULL)));
}
}
void gmachine_pack(struct gmachine* g, size_t n, int8_t t) {
assert(g->stack.count >= n);
struct node_base** data = malloc(sizeof(*data) * (n + 1));
assert(data != NULL);
memcpy(data, &g->stack.data[g->stack.count - n], n * sizeof(*data));
data[n] = NULL;
struct node_data* new_node = (struct node_data*) alloc_node();
new_node->array = data;
new_node->base.tag = NODE_DATA;
new_node->tag = t;
stack_popn(&g->stack, n);
stack_push(&g->stack, gmachine_track(g, (struct node_base*) new_node));
}
void gmachine_split(struct gmachine* g, size_t n) {
struct node_data* node = (struct node_data*) stack_pop(&g->stack);
for(size_t i = 0; i < n; i++) {
stack_push(&g->stack, node->array[i]);
}
}
struct node_base* gmachine_track(struct gmachine* g, struct node_base* b) {
g->gc_node_count++;
b->gc_next = g->gc_nodes;
g->gc_nodes = b;
if(g->gc_node_count >= g->gc_node_threshold) {
uint64_t nodes_before = g->gc_node_count;
gc_visit_node(b);
gmachine_gc(g);
g->gc_node_threshold = g->gc_node_count * 2;
}
return b;
}
void gmachine_gc(struct gmachine* g) {
for(size_t i = 0; i < g->stack.count; i++) {
gc_visit_node(g->stack.data[i]);
}
struct node_base** head_ptr = &g->gc_nodes;
while(*head_ptr) {
if((*head_ptr)->gc_reachable) {
(*head_ptr)->gc_reachable = 0;
head_ptr = &(*head_ptr)->gc_next;
} else {
struct node_base* to_free = *head_ptr;
*head_ptr = to_free->gc_next;
free_node_direct(to_free);
free(to_free);
g->gc_node_count--;
}
}
}
void unwind(struct gmachine* g) {
struct stack* s = &g->stack;
while(1) {
struct node_base* peek = stack_peek(s, 0);
if(peek->tag == NODE_APP) {
struct node_app* n = (struct node_app*) peek;
stack_push(s, n->left);
} else if(peek->tag == NODE_GLOBAL) {
struct node_global* n = (struct node_global*) peek;
assert(s->count > n->arity);
for(size_t i = 1; i <= n->arity; i++) {
s->data[s->count - i]
= ((struct node_app*) s->data[s->count - i - 1])->right;
}
n->function(g);
} else if(peek->tag == NODE_IND) {
struct node_ind* n = (struct node_ind*) peek;
stack_pop(s);
stack_push(s, n->next);
} else {
break;
}
}
}
extern void f_main(struct gmachine* s);
void print_node(struct node_base* n) {
if(n->tag == NODE_APP) {
struct node_app* app = (struct node_app*) n;
print_node(app->left);
putchar(' ');
print_node(app->right);
} else if(n->tag == NODE_DATA) {
printf("(Packed)");
} else if(n->tag == NODE_GLOBAL) {
struct node_global* global = (struct node_global*) n;
printf("(Global: %p)", global->function);
} else if(n->tag == NODE_IND) {
print_node(((struct node_ind*) n)->next);
} else if(n->tag == NODE_NUM) {
struct node_num* num = (struct node_num*) n;
printf("%d", num->value);
}
}
int main(int argc, char** argv) {
struct gmachine gmachine;
struct node_global* first_node = alloc_global(f_main, 0);
struct node_base* result;
gmachine_init(&gmachine);
gmachine_track(&gmachine, (struct node_base*) first_node);
stack_push(&gmachine.stack, (struct node_base*) first_node);
unwind(&gmachine);
result = stack_pop(&gmachine.stack);
printf("Result: ");
print_node(result);
putchar('\n');
gmachine_free(&gmachine);
}

View File

@@ -0,0 +1,84 @@
#pragma once
#include <stdlib.h>
struct gmachine;
enum node_tag {
NODE_APP,
NODE_NUM,
NODE_GLOBAL,
NODE_IND,
NODE_DATA
};
struct node_base {
enum node_tag tag;
int8_t gc_reachable;
struct node_base* gc_next;
};
struct node_app {
struct node_base base;
struct node_base* left;
struct node_base* right;
};
struct node_num {
struct node_base base;
int32_t value;
};
struct node_global {
struct node_base base;
int32_t arity;
void (*function)(struct gmachine*);
};
struct node_ind {
struct node_base base;
struct node_base* next;
};
struct node_data {
struct node_base base;
int8_t tag;
struct node_base** array;
};
struct node_base* alloc_node();
struct node_app* alloc_app(struct node_base* l, struct node_base* r);
struct node_num* alloc_num(int32_t n);
struct node_global* alloc_global(void (*f)(struct gmachine*), int32_t a);
struct node_ind* alloc_ind(struct node_base* n);
void free_node_direct(struct node_base*);
void gc_visit_node(struct node_base*);
struct stack {
size_t size;
size_t count;
struct node_base** data;
};
void stack_init(struct stack* s);
void stack_free(struct stack* s);
void stack_push(struct stack* s, struct node_base* n);
struct node_base* stack_pop(struct stack* s);
struct node_base* stack_peek(struct stack* s, size_t o);
void stack_popn(struct stack* s, size_t n);
struct gmachine {
struct stack stack;
struct node_base* gc_nodes;
int64_t gc_node_count;
int64_t gc_node_threshold;
};
void gmachine_init(struct gmachine* g);
void gmachine_free(struct gmachine* g);
void gmachine_slide(struct gmachine* g, size_t n);
void gmachine_update(struct gmachine* g, size_t o);
void gmachine_alloc(struct gmachine* g, size_t o);
void gmachine_pack(struct gmachine* g, size_t n, int8_t t);
void gmachine_split(struct gmachine* g, size_t n);
struct node_base* gmachine_track(struct gmachine* g, struct node_base* b);
void gmachine_gc(struct gmachine* g);

View File

@@ -0,0 +1,45 @@
%option noyywrap
%option reentrant
%option header-file="scanner.hpp"
%{
#include <iostream>
#include "ast.hpp"
#include "definition.hpp"
#include "parse_driver.hpp"
#include "parser.hpp"
#define YY_USER_ACTION \
drv.get_file_manager().write(yytext, yyleng); \
LOC.step(); LOC.columns(yyleng);
#define LOC drv.get_current_location()
%}
%%
\n { drv.get_current_location().lines(); drv.get_file_manager().mark_line(); }
[ ]+ {}
\\ { return yy::parser::make_BACKSLASH(LOC); }
\+ { return yy::parser::make_PLUS(LOC); }
\* { return yy::parser::make_TIMES(LOC); }
- { return yy::parser::make_MINUS(LOC); }
\/ { return yy::parser::make_DIVIDE(LOC); }
[0-9]+ { return yy::parser::make_INT(atoi(yytext), LOC); }
defn { return yy::parser::make_DEFN(LOC); }
data { return yy::parser::make_DATA(LOC); }
case { return yy::parser::make_CASE(LOC); }
of { return yy::parser::make_OF(LOC); }
let { return yy::parser::make_LET(LOC); }
in { return yy::parser::make_IN(LOC); }
\{ { return yy::parser::make_OCURLY(LOC); }
\} { return yy::parser::make_CCURLY(LOC); }
\( { return yy::parser::make_OPAREN(LOC); }
\) { return yy::parser::make_CPAREN(LOC); }
, { return yy::parser::make_COMMA(LOC); }
-> { return yy::parser::make_ARROW(LOC); }
= { return yy::parser::make_EQUAL(LOC); }
[a-z][a-zA-Z]* { return yy::parser::make_LID(std::string(yytext), LOC); }
[A-Z][a-zA-Z]* { return yy::parser::make_UID(std::string(yytext), LOC); }
<<EOF>> { return yy::parser::make_YYEOF(LOC); }
%%

23
code/compiler/13/test.cpp Normal file
View File

@@ -0,0 +1,23 @@
#include "graph.hpp"
int main() {
function_graph graph;
graph.add_edge("f", "g");
graph.add_edge("g", "h");
graph.add_edge("h", "f");
graph.add_edge("i", "j");
graph.add_edge("j", "i");
graph.add_edge("j", "f");
graph.add_edge("x", "f");
graph.add_edge("x", "i");
for(auto& group : graph.compute_order()) {
std::cout << "Group: " << std::endl;
for(auto& member : group->members) {
std::cout << member << std::endl;
}
}
}

213
code/compiler/13/type.cpp Normal file
View File

@@ -0,0 +1,213 @@
#include "type.hpp"
#include <ostream>
#include <sstream>
#include <algorithm>
#include <vector>
#include "error.hpp"
void type_scheme::print(const type_mgr& mgr, std::ostream& to) const {
if(forall.size() != 0) {
to << "forall ";
for(auto& var : forall) {
to << var << " ";
}
to << ". ";
}
monotype->print(mgr, to);
}
type_ptr type_scheme::instantiate(type_mgr& mgr) const {
if(forall.size() == 0) return monotype;
std::map<std::string, type_ptr> subst;
for(auto& var : forall) {
subst[var] = mgr.new_type();
}
return mgr.substitute(subst, monotype);
}
void type_var::print(const type_mgr& mgr, std::ostream& to) const {
auto type = mgr.lookup(name);
if(type) {
type->print(mgr, to);
} else {
to << name;
}
}
void type_base::print(const type_mgr& mgr, std::ostream& to) const {
to << name;
}
void type_arr::print(const type_mgr& mgr, std::ostream& to) const {
type_var* var;
bool print_parenths = dynamic_cast<type_arr*>(mgr.resolve(left, var).get()) != nullptr;
if(print_parenths) to << "(";
left->print(mgr, to);
if(print_parenths) to << ")";
to << " -> ";
right->print(mgr, to);
}
void type_app::print(const type_mgr& mgr, std::ostream& to) const {
constructor->print(mgr, to);
to << "*";
for(auto& arg : arguments) {
to << " ";
arg->print(mgr, to);
}
}
std::string type_mgr::new_type_name() {
int temp = last_id++;
std::string str = "";
while(temp != -1) {
str += (char) ('a' + (temp % 26));
temp = temp / 26 - 1;
}
std::reverse(str.begin(), str.end());
return str;
}
type_ptr type_mgr::new_type() {
return type_ptr(new type_var(new_type_name()));
}
type_ptr type_mgr::new_arrow_type() {
return type_ptr(new type_arr(new_type(), new_type()));
}
type_ptr type_mgr::lookup(const std::string& var) const {
auto types_it = types.find(var);
if(types_it != types.end()) return types_it->second;
return nullptr;
}
type_ptr type_mgr::resolve(type_ptr t, type_var*& var) const {
type_var* cast;
var = nullptr;
while((cast = dynamic_cast<type_var*>(t.get()))) {
auto it = types.find(cast->name);
if(it == types.end()) {
var = cast;
break;
}
t = it->second;
}
return t;
}
void type_mgr::unify(type_ptr l, type_ptr r, const std::optional<yy::location>& loc) {
type_var *lvar, *rvar;
type_arr *larr, *rarr;
type_base *lid, *rid;
type_app *lapp, *rapp;
l = resolve(l, lvar);
r = resolve(r, rvar);
if(lvar) {
bind(lvar->name, r);
return;
} else if(rvar) {
bind(rvar->name, l);
return;
} else if((larr = dynamic_cast<type_arr*>(l.get())) &&
(rarr = dynamic_cast<type_arr*>(r.get()))) {
unify(larr->left, rarr->left, loc);
unify(larr->right, rarr->right, loc);
return;
} else if((lid = dynamic_cast<type_base*>(l.get())) &&
(rid = dynamic_cast<type_base*>(r.get()))) {
if(lid->name == rid->name &&
lid->arity == rid->arity)
return;
} else if((lapp = dynamic_cast<type_app*>(l.get())) &&
(rapp = dynamic_cast<type_app*>(r.get()))) {
unify(lapp->constructor, rapp->constructor, loc);
auto left_it = lapp->arguments.begin();
auto right_it = rapp->arguments.begin();
while(left_it != lapp->arguments.end() &&
right_it != rapp->arguments.end()) {
unify(*left_it, *right_it, loc);
left_it++, right_it++;
}
return;
}
throw unification_error(l, r, loc);
}
type_ptr type_mgr::substitute(const std::map<std::string, type_ptr>& subst, const type_ptr& t) const {
type_ptr temp = t;
while(type_var* var = dynamic_cast<type_var*>(temp.get())) {
auto subst_it = subst.find(var->name);
if(subst_it != subst.end()) return subst_it->second;
auto var_it = types.find(var->name);
if(var_it == types.end()) return t;
temp = var_it->second;
}
if(type_arr* arr = dynamic_cast<type_arr*>(temp.get())) {
auto left_result = substitute(subst, arr->left);
auto right_result = substitute(subst, arr->right);
if(left_result == arr->left && right_result == arr->right) return t;
return type_ptr(new type_arr(left_result, right_result));
} else if(type_app* app = dynamic_cast<type_app*>(temp.get())) {
auto constructor_result = substitute(subst, app->constructor);
bool arg_changed = false;
std::vector<type_ptr> new_args;
for(auto& arg : app->arguments) {
auto arg_result = substitute(subst, arg);
arg_changed |= arg_result != arg;
new_args.push_back(std::move(arg_result));
}
if(constructor_result == app->constructor && !arg_changed) return t;
type_app* new_app = new type_app(std::move(constructor_result));
std::swap(new_app->arguments, new_args);
return type_ptr(new_app);
}
return t;
}
void type_mgr::bind(const std::string& s, type_ptr t) {
type_var* other = dynamic_cast<type_var*>(t.get());
if(other && other->name == s) return;
types[s] = t;
}
void type_mgr::find_free(const type_ptr& t, std::set<std::string>& into) const {
type_var* var;
type_ptr resolved = resolve(t, var);
if(var) {
into.insert(var->name);
} else if(type_arr* arr = dynamic_cast<type_arr*>(resolved.get())) {
find_free(arr->left, into);
find_free(arr->right, into);
} else if(type_app* app = dynamic_cast<type_app*>(resolved.get())) {
find_free(app->constructor, into);
for(auto& arg : app->arguments) find_free(arg, into);
}
}
void type_mgr::find_free(const type_scheme_ptr& t, std::set<std::string>& into) const {
std::set<std::string> monotype_free;
type_mgr limited_mgr;
for(auto& binding : types) {
auto existing_position = std::find(t->forall.begin(), t->forall.end(), binding.first);
if(existing_position != t->forall.end()) continue;
limited_mgr.types[binding.first] = binding.second;
}
limited_mgr.find_free(t->monotype, monotype_free);
for(auto& not_free : t->forall) {
monotype_free.erase(not_free);
}
into.insert(monotype_free.begin(), monotype_free.end());
}

101
code/compiler/13/type.hpp Normal file
View File

@@ -0,0 +1,101 @@
#pragma once
#include <memory>
#include <map>
#include <string>
#include <vector>
#include <set>
#include <optional>
#include "location.hh"
class type_mgr;
struct type {
virtual ~type() = default;
virtual void print(const type_mgr& mgr, std::ostream& to) const = 0;
};
using type_ptr = std::shared_ptr<type>;
struct type_scheme {
std::vector<std::string> forall;
type_ptr monotype;
type_scheme(type_ptr type) : forall(), monotype(std::move(type)) {}
void print(const type_mgr& mgr, std::ostream& to) const;
type_ptr instantiate(type_mgr& mgr) const;
};
using type_scheme_ptr = std::shared_ptr<type_scheme>;
struct type_var : public type {
std::string name;
type_var(std::string n)
: name(std::move(n)) {}
void print(const type_mgr& mgr, std::ostream& to) const;
};
struct type_base : public type {
std::string name;
int32_t arity;
type_base(std::string n, int32_t a = 0)
: name(std::move(n)), arity(a) {}
void print(const type_mgr& mgr, std::ostream& to) const;
};
struct type_data : public type_base {
struct constructor {
int tag;
};
std::map<std::string, constructor> constructors;
type_data(std::string n, int32_t a = 0)
: type_base(std::move(n), a) {}
};
struct type_arr : public type {
type_ptr left;
type_ptr right;
type_arr(type_ptr l, type_ptr r)
: left(std::move(l)), right(std::move(r)) {}
void print(const type_mgr& mgr, std::ostream& to) const;
};
struct type_app : public type {
type_ptr constructor;
std::vector<type_ptr> arguments;
type_app(type_ptr c)
: constructor(std::move(c)) {}
void print(const type_mgr& mgr, std::ostream& to) const;
};
class type_mgr {
private:
int last_id = 0;
std::map<std::string, type_ptr> types;
public:
std::string new_type_name();
type_ptr new_type();
type_ptr new_arrow_type();
void unify(type_ptr l, type_ptr r, const std::optional<yy::location>& loc = std::nullopt);
type_ptr substitute(
const std::map<std::string, type_ptr>& subst,
const type_ptr& t) const;
type_ptr lookup(const std::string& var) const;
type_ptr resolve(type_ptr t, type_var*& var) const;
void bind(const std::string& s, type_ptr t);
void find_free(const type_ptr& t, std::set<std::string>& into) const;
void find_free(const type_scheme_ptr& t, std::set<std::string>& into) const;
};

View File

@@ -0,0 +1,96 @@
#include "type_env.hpp"
#include "type.hpp"
#include "error.hpp"
#include <cassert>
void type_env::find_free(const type_mgr& mgr, std::set<std::string>& into) const {
if(parent != nullptr) parent->find_free(mgr, into);
for(auto& binding : names) {
mgr.find_free(binding.second.type, into);
}
}
void type_env::find_free_except(const type_mgr& mgr, const group& avoid,
std::set<std::string>& into) const {
if(parent != nullptr) parent->find_free(mgr, into);
for(auto& binding : names) {
if(avoid.members.find(binding.first) != avoid.members.end()) continue;
mgr.find_free(binding.second.type, into);
}
}
type_scheme_ptr type_env::lookup(const std::string& name) const {
auto it = names.find(name);
if(it != names.end()) return it->second.type;
if(parent) return parent->lookup(name);
return nullptr;
}
bool type_env::is_global(const std::string& name) const {
auto it = names.find(name);
if(it != names.end()) return it->second.vis == visibility::global;
if(parent) return parent->is_global(name);
return false;
}
void type_env::set_mangled_name(const std::string& name, const std::string& mangled) {
auto it = names.find(name);
// Can't set mangled name for non-existent variable.
assert(it != names.end());
// Local names shouldn't need mangling.
assert(it->second.vis == visibility::global);
it->second.mangled_name = mangled;
}
const std::string& type_env::get_mangled_name(const std::string& name) const {
auto it = names.find(name);
if(it != names.end()) {
assert(it->second.mangled_name);
return *it->second.mangled_name;
}
assert(parent != nullptr);
return parent->get_mangled_name(name);
}
type_ptr type_env::lookup_type(const std::string& name) const {
auto it = type_names.find(name);
if(it != type_names.end()) return it->second;
if(parent) return parent->lookup_type(name);
return nullptr;
}
void type_env::bind(const std::string& name, type_ptr t, visibility v) {
type_scheme_ptr new_scheme(new type_scheme(std::move(t)));
names[name] = variable_data(std::move(new_scheme), v, std::nullopt);
}
void type_env::bind(const std::string& name, type_scheme_ptr t, visibility v) {
names[name] = variable_data(std::move(t), v, "");
}
void type_env::bind_type(const std::string& type_name, type_ptr t) {
if(lookup_type(type_name) != nullptr)
throw type_error("redefinition of type");
type_names[type_name] = t;
}
void type_env::generalize(const std::string& name, const group& grp, type_mgr& mgr) {
auto names_it = names.find(name);
assert(names_it != names.end());
assert(names_it->second.type->forall.size() == 0);
std::set<std::string> free_in_type;
std::set<std::string> free_in_env;
mgr.find_free(names_it->second.type->monotype, free_in_type);
find_free_except(mgr, grp, free_in_env);
for(auto& free : free_in_type) {
if(free_in_env.find(free) != free_in_env.end()) continue;
names_it->second.type->forall.push_back(free);
}
}
type_env_ptr type_scope(type_env_ptr parent) {
return type_env_ptr(new type_env(std::move(parent)));
}

View File

@@ -0,0 +1,52 @@
#pragma once
#include <map>
#include <string>
#include <set>
#include <optional>
#include "graph.hpp"
#include "type.hpp"
struct type_env;
using type_env_ptr = std::shared_ptr<type_env>;
enum class visibility { global,local };
class type_env {
private:
struct variable_data {
type_scheme_ptr type;
visibility vis;
std::optional<std::string> mangled_name;
variable_data()
: variable_data(nullptr, visibility::local, std::nullopt) {}
variable_data(type_scheme_ptr t, visibility v, std::optional<std::string> n)
: type(std::move(t)), vis(v), mangled_name(std::move(n)) {}
};
type_env_ptr parent;
std::map<std::string, variable_data> names;
std::map<std::string, type_ptr> type_names;
public:
type_env(type_env_ptr p) : parent(std::move(p)) {}
type_env() : type_env(nullptr) {}
void find_free(const type_mgr& mgr, std::set<std::string>& into) const;
void find_free_except(const type_mgr& mgr, const group& avoid,
std::set<std::string>& into) const;
type_scheme_ptr lookup(const std::string& name) const;
bool is_global(const std::string& name) const;
void set_mangled_name(const std::string& name, const std::string& mangled);
const std::string& get_mangled_name(const std::string& name) const;
type_ptr lookup_type(const std::string& name) const;
void bind(const std::string& name, type_ptr t,
visibility v = visibility::local);
void bind(const std::string& name, type_scheme_ptr t,
visibility v = visibility::local);
void bind_type(const std::string& type_name, type_ptr t);
void generalize(const std::string& name, const group& grp, type_mgr& mgr);
};
type_env_ptr type_scope(type_env_ptr parent);

View File

@@ -0,0 +1,21 @@
takeUntilMax :: [Int] -> Int -> (Int, [Int])
takeUntilMax [] m = (m, [])
takeUntilMax [x] _ = (x, [x])
takeUntilMax (x:xs) m
| x == m = (x, [x])
| otherwise =
let (m', xs') = takeUntilMax xs m
in (max m' x, x:xs')
doTakeUntilMax :: [Int] -> [Int]
doTakeUntilMax l = l'
where (m, l') = takeUntilMax l m
takeUntilMax' :: [Int] -> Int -> (Int, [Int])
takeUntilMax' [] m = (m, [])
takeUntilMax' [x] _ = (x, [x])
takeUntilMax' (x:xs) m
| x == m = (maximum (x:xs), [x])
| otherwise =
let (m', xs') = takeUntilMax' xs m
in (max m' x, x:xs')

View File

@@ -0,0 +1,28 @@
import Data.Map as Map
import Data.Maybe
import Control.Applicative
data Element = A | B | C | D
deriving (Eq, Ord, Show)
addElement :: Element -> Map Element Int -> Map Element Int
addElement = alter ((<|> Just 1) . fmap (+1))
getScore :: Element -> Map Element Int -> Float
getScore e m = fromMaybe 1.0 $ ((1.0/) . fromIntegral) <$> Map.lookup e m
data BinaryTree a = Empty | Node a (BinaryTree a) (BinaryTree a) deriving Show
type ElementTree = BinaryTree Element
type ScoredElementTree = BinaryTree (Element, Float)
assignScores :: ElementTree -> Map Element Int -> (Map Element Int, ScoredElementTree)
assignScores Empty m = (Map.empty, Empty)
assignScores (Node e t1 t2) m = (m', Node (e, getScore e m) t1' t2')
where
(m1, t1') = assignScores t1 m
(m2, t2') = assignScores t2 m
m' = addElement e $ unionWith (+) m1 m2
doAssignScores :: ElementTree -> ScoredElementTree
doAssignScores t = t'
where (m, t') = assignScores t m

View File

@@ -0,0 +1,102 @@
data Reg = A | B | R
data Ty = IntTy | BoolTy
TypeState : Type
TypeState = (Ty, Ty, Ty)
getRegTy : Reg -> TypeState -> Ty
getRegTy A (a, _, _) = a
getRegTy B (_, b, _) = b
getRegTy R (_, _, r) = r
setRegTy : Reg -> Ty -> TypeState -> TypeState
setRegTy A a (_, b, r) = (a, b, r)
setRegTy B b (a, _, r) = (a, b, r)
setRegTy R r (a, b, _) = (a, b, r)
data Expr : TypeState -> Ty -> Type where
Lit : Int -> Expr s IntTy
Load : (r : Reg) -> Expr s (getRegTy r s)
Add : Expr s IntTy -> Expr s IntTy -> Expr s IntTy
Leq : Expr s IntTy -> Expr s IntTy -> Expr s BoolTy
Not : Expr s BoolTy -> Expr s BoolTy
mutual
data Stmt : TypeState -> TypeState -> TypeState -> Type where
Store : (r : Reg) -> Expr s t -> Stmt l s (setRegTy r t s)
If : Expr s BoolTy -> Prog l s n -> Prog l s n -> Stmt l s n
Loop : Prog s s s -> Stmt l s s
Break : Stmt s s s
data Prog : TypeState -> TypeState -> TypeState -> Type where
Nil : Prog l s s
(::) : Stmt l s n -> Prog l n m -> Prog l s m
initialState : TypeState
initialState = (IntTy, IntTy, IntTy)
testProg : Prog Main.initialState Main.initialState Main.initialState
testProg =
[ Store A (Lit 1 `Leq` Lit 2)
, If (Load A)
[ Store A (Lit 1) ]
[ Store A (Lit 2) ]
, Store B (Lit 2)
, Store R (Add (Load A) (Load B))
]
prodProg : Prog Main.initialState Main.initialState Main.initialState
prodProg =
[ Store A (Lit 7)
, Store B (Lit 9)
, Store R (Lit 0)
, Loop
[ If (Load A `Leq` Lit 0)
[ Break ]
[ Store R (Load R `Add` Load B)
, Store A (Load A `Add` Lit (-1))
]
]
]
repr : Ty -> Type
repr IntTy = Int
repr BoolTy = Bool
data State : TypeState -> Type where
MkState : (repr a, repr b, repr c) -> State (a, b, c)
getReg : (r : Reg) -> State s -> repr (getRegTy r s)
getReg A (MkState (a, _, _)) = a
getReg B (MkState (_, b, _)) = b
getReg R (MkState (_, _, r)) = r
setReg : (r : Reg) -> repr t -> State s -> State (setRegTy r t s)
setReg A a (MkState (_, b, r)) = MkState (a, b, r)
setReg B b (MkState (a, _, r)) = MkState (a, b, r)
setReg R r (MkState (a, b, _)) = MkState (a, b, r)
expr : Expr s t -> State s -> repr t
expr (Lit i) _ = i
expr (Load r) s = getReg r s
expr (Add l r) s = expr l s + expr r s
expr (Leq l r) s = expr l s <= expr r s
expr (Not e) s = not $ expr e s
mutual
stmt : Stmt l s n -> State s -> Either (State l) (State n)
stmt (Store r e) s = Right $ setReg r (expr e s) s
stmt (If c t e) s = if expr c s then prog t s else prog e s
stmt (Loop p) s =
case prog p s >>= stmt (Loop p) of
Right s => Right s
Left s => Right s
stmt Break s = Left s
prog : Prog l s n -> State s -> Either (State l) (State n)
prog Nil s = Right s
prog (st::p) s = stmt st s >>= prog p
run : Prog l s l -> State s -> State l
run p s = either id id $ prog p s

View File

@@ -0,0 +1,99 @@
data ExprType
= IntType
| BoolType
| StringType
repr : ExprType -> Type
repr IntType = Int
repr BoolType = Bool
repr StringType = String
intBoolImpossible : IntType = BoolType -> Void
intBoolImpossible Refl impossible
intStringImpossible : IntType = StringType -> Void
intStringImpossible Refl impossible
boolStringImpossible : BoolType = StringType -> Void
boolStringImpossible Refl impossible
decEq : (a : ExprType) -> (b : ExprType) -> Dec (a = b)
decEq IntType IntType = Yes Refl
decEq BoolType BoolType = Yes Refl
decEq StringType StringType = Yes Refl
decEq IntType BoolType = No intBoolImpossible
decEq BoolType IntType = No $ intBoolImpossible . sym
decEq IntType StringType = No intStringImpossible
decEq StringType IntType = No $ intStringImpossible . sym
decEq BoolType StringType = No boolStringImpossible
decEq StringType BoolType = No $ boolStringImpossible . sym
data Op
= Add
| Subtract
| Multiply
| Divide
data Expr
= IntLit Int
| BoolLit Bool
| StringLit String
| BinOp Op Expr Expr
| IfElse Expr Expr Expr
data SafeExpr : ExprType -> Type where
IntLiteral : Int -> SafeExpr IntType
BoolLiteral : Bool -> SafeExpr BoolType
StringLiteral : String -> SafeExpr StringType
BinOperation : (repr a -> repr b -> repr c) -> SafeExpr a -> SafeExpr b -> SafeExpr c
IfThenElse : SafeExpr BoolType -> SafeExpr t -> SafeExpr t -> SafeExpr t
typecheckOp : Op -> (a : ExprType) -> (b : ExprType) -> Either String (c : ExprType ** repr a -> repr b -> repr c)
typecheckOp Add IntType IntType = Right (IntType ** (+))
typecheckOp Subtract IntType IntType = Right (IntType ** (-))
typecheckOp Multiply IntType IntType = Right (IntType ** (*))
typecheckOp Divide IntType IntType = Right (IntType ** div)
typecheckOp _ _ _ = Left "Invalid binary operator application"
requireBool : (n : ExprType ** SafeExpr n) -> Either String (SafeExpr BoolType)
requireBool (BoolType ** e) = Right e
requireBool _ = Left "Not a boolean."
typecheck : Expr -> Either String (n : ExprType ** SafeExpr n)
typecheck (IntLit i) = Right (_ ** IntLiteral i)
typecheck (BoolLit b) = Right (_ ** BoolLiteral b)
typecheck (StringLit s) = Right (_ ** StringLiteral s)
typecheck (BinOp o l r) = do
(lt ** le) <- typecheck l
(rt ** re) <- typecheck r
(ot ** f) <- typecheckOp o lt rt
pure (_ ** BinOperation f le re)
typecheck (IfElse c t e) =
do
ce <- typecheck c >>= requireBool
(tt ** te) <- typecheck t
(et ** ee) <- typecheck e
case decEq tt et of
Yes p => pure (_ ** IfThenElse ce (replace p te) ee)
No _ => Left "Incompatible branch types."
eval : SafeExpr t -> repr t
eval (IntLiteral i) = i
eval (BoolLiteral b) = b
eval (StringLiteral s) = s
eval (BinOperation f l r) = f (eval l) (eval r)
eval (IfThenElse c t e) = if (eval c) then (eval t) else (eval e)
resultStr : {t : ExprType} -> repr t -> String
resultStr {t=IntType} i = show i
resultStr {t=BoolType} b = show b
resultStr {t=StringType} s = show s
tryEval : Expr -> String
tryEval ex =
case typecheck ex of
Left err => "Type error: " ++ err
Right (t ** e) => resultStr $ eval {t} e
main : IO ()
main = putStrLn $ tryEval $ BinOp Add (IfElse (BoolLit True) (IntLit 6) (IntLit 7)) (BinOp Multiply (IntLit 160) (IntLit 2))

View File

@@ -0,0 +1,120 @@
data ExprType
= IntType
| BoolType
| StringType
| PairType ExprType ExprType
repr : ExprType -> Type
repr IntType = Int
repr BoolType = Bool
repr StringType = String
repr (PairType t1 t2) = Pair (repr t1) (repr t2)
decEq : (a : ExprType) -> (b : ExprType) -> Maybe (a = b)
decEq IntType IntType = Just Refl
decEq BoolType BoolType = Just Refl
decEq StringType StringType = Just Refl
decEq (PairType lt1 lt2) (PairType rt1 rt2) = do
subEq1 <- decEq lt1 rt1
subEq2 <- decEq lt2 rt2
let firstEqual = replace {P = \t1 => PairType lt1 lt2 = PairType t1 lt2} subEq1 Refl
let secondEqual = replace {P = \t2 => PairType lt1 lt2 = PairType rt1 t2} subEq2 firstEqual
pure secondEqual
decEq _ _ = Nothing
data Op
= Add
| Subtract
| Multiply
| Divide
data Expr
= IntLit Int
| BoolLit Bool
| StringLit String
| BinOp Op Expr Expr
| IfElse Expr Expr Expr
| Pair Expr Expr
| Fst Expr
| Snd Expr
data SafeExpr : ExprType -> Type where
IntLiteral : Int -> SafeExpr IntType
BoolLiteral : Bool -> SafeExpr BoolType
StringLiteral : String -> SafeExpr StringType
BinOperation : (repr a -> repr b -> repr c) -> SafeExpr a -> SafeExpr b -> SafeExpr c
IfThenElse : SafeExpr BoolType -> SafeExpr t -> SafeExpr t -> SafeExpr t
NewPair : SafeExpr t1 -> SafeExpr t2 -> SafeExpr (PairType t1 t2)
First : SafeExpr (PairType t1 t2) -> SafeExpr t1
Second : SafeExpr (PairType t1 t2) -> SafeExpr t2
typecheckOp : Op -> (a : ExprType) -> (b : ExprType) -> Either String (c : ExprType ** repr a -> repr b -> repr c)
typecheckOp Add IntType IntType = Right (IntType ** (+))
typecheckOp Subtract IntType IntType = Right (IntType ** (-))
typecheckOp Multiply IntType IntType = Right (IntType ** (*))
typecheckOp Divide IntType IntType = Right (IntType ** div)
typecheckOp _ _ _ = Left "Invalid binary operator application"
requireBool : (n : ExprType ** SafeExpr n) -> Either String (SafeExpr BoolType)
requireBool (BoolType ** e) = Right e
requireBool _ = Left "Not a boolean."
typecheck : Expr -> Either String (n : ExprType ** SafeExpr n)
typecheck (IntLit i) = Right (_ ** IntLiteral i)
typecheck (BoolLit b) = Right (_ ** BoolLiteral b)
typecheck (StringLit s) = Right (_ ** StringLiteral s)
typecheck (BinOp o l r) = do
(lt ** le) <- typecheck l
(rt ** re) <- typecheck r
(ot ** f) <- typecheckOp o lt rt
pure (_ ** BinOperation f le re)
typecheck (IfElse c t e) =
do
ce <- typecheck c >>= requireBool
(tt ** te) <- typecheck t
(et ** ee) <- typecheck e
case decEq tt et of
Just p => pure (_ ** IfThenElse ce (replace p te) ee)
Nothing => Left "Incompatible branch types."
typecheck (Pair l r) =
do
(lt ** le) <- typecheck l
(rt ** re) <- typecheck r
pure (_ ** NewPair le re)
typecheck (Fst p) =
do
(pt ** pe) <- typecheck p
case pt of
PairType _ _ => pure $ (_ ** First pe)
_ => Left "Applying fst to non-pair."
typecheck (Snd p) =
do
(pt ** pe) <- typecheck p
case pt of
PairType _ _ => pure $ (_ ** Second pe)
_ => Left "Applying snd to non-pair."
eval : SafeExpr t -> repr t
eval (IntLiteral i) = i
eval (BoolLiteral b) = b
eval (StringLiteral s) = s
eval (BinOperation f l r) = f (eval l) (eval r)
eval (IfThenElse c t e) = if (eval c) then (eval t) else (eval e)
eval (NewPair l r) = (eval l, eval r)
eval (First p) = fst (eval p)
eval (Second p) = snd (eval p)
resultStr : {t : ExprType} -> repr t -> String
resultStr {t=IntType} i = show i
resultStr {t=BoolType} b = show b
resultStr {t=StringType} s = show s
resultStr {t=PairType t1 t2} (l,r) = "(" ++ resultStr l ++ ", " ++ resultStr r ++ ")"
tryEval : Expr -> String
tryEval ex =
case typecheck ex of
Left err => "Type error: " ++ err
Right (t ** e) => resultStr $ eval {t} e
main : IO ()
main = putStrLn $ tryEval $ BinOp Add (Fst (IfElse (BoolLit True) (Pair (IntLit 6) (BoolLit True)) (Pair (IntLit 7) (BoolLit False)))) (BinOp Multiply (IntLit 160) (IntLit 2))

View File

@@ -3,5 +3,20 @@ languageCode = "en-us"
title = "Daniel's Blog" title = "Daniel's Blog"
theme = "vanilla" theme = "vanilla"
pygmentsCodeFences = true pygmentsCodeFences = true
pygmentsStyle = "github" pygmentsUseClasses = true
summaryLength = 20 summaryLength = 20
[outputFormats]
[outputFormats.Toml]
name = "toml"
mediaType = "application/toml"
isHTML = false
[outputs]
home = ["html","rss","toml"]
[markup]
[markup.tableOfContents]
endLevel = 4
ordered = false
startLevel = 3

View File

@@ -1,8 +1,8 @@
--- ---
title: About title: About
--- ---
I'm Daniel, a Computer Science student currently in my third (and final) undergraduate year at Oregon State University. I'm Daniel, a Computer Science student currently working towards my Master's Degree at Oregon State University.
Due my initial interest in calculators and compilers, I got involved in the Programming Language Theory research Due to my initial interest in calculators and compilers, I got involved in the Programming Language Theory research
group, gaining same experience in formal verification, domain specific language, and explainable computing. group, gaining same experience in formal verification, domain specific language, and explainable computing.
For work, school, and hobby projects, I use a variety of programming languages, most commonly C/C++, For work, school, and hobby projects, I use a variety of programming languages, most commonly C/C++,

350
content/blog/00_aoc_coq.md Normal file
View File

@@ -0,0 +1,350 @@
---
title: "Advent of Code in Coq - Day 1"
date: 2020-12-02T18:44:56-08:00
tags: ["Advent of Code", "Coq"]
---
The first puzzle of this year's [Advent of Code](https://adventofcode.com) was quite
simple, which gave me a thought: "Hey, this feels within reach for me to formally verify!"
At first, I wanted to formalize and prove the correctness of the [two-pointer solution](https://www.geeksforgeeks.org/two-pointers-technique/).
However, I didn't have the time to mess around with the various properties of sorted
lists and their traversals. So, I settled for the brute force solution. Despite
the simplicity of its implementation, there is plenty to talk about when proving
its correctness using Coq. Let's get right into it!
Before we start, in the interest of keeping the post self-contained, here's the (paraphrased)
problem statement:
> Given an unsorted list of numbers, find two distinct numbers that add up to 2020.
With this in mind, we can move on to writing some Coq!
### Defining the Functions
The first step to proving our code correct is to actually write the code! To start with,
let's write a helper function that, given a number `x`, tries to find another number
`y` such that `x + y = 2020`. In fact, rather than hardcoding the desired
sum to `2020`, let's just use another argument called `total`. The code is quite simple:
{{< codelines "Coq" "aoc-coq/day1.v" 7 14 >}}
Here, `is` is the list of numbers that we want to search.
We proceed by case analysis: if the list is empty, we can't
find a match, so we return `None` (the Coq equivalent of Haskell's `Nothing`).
On the other hand, if the list has at least one element `y`, we see if it adds
up to `total`, and return `Some y` (equivalent to `Just y` in Haskell) if it does.
If it doesn't, we continue our search into the rest of the list.
It's somewhat unusual, in my experience, to put the list argument first when writing
functions in a language with [currying](https://wiki.haskell.org/Currying). However,
it seems as though Coq's `simpl` tactic, which we will use later, works better
for our purposes when the argument being case analyzed is given first.
We can now use `find_matching` to define our `find_sum` function, which solves part 1.
Here's the code:
{{< codelines "Coq" "aoc-coq/day1.v" 16 24 >}}
For every `x` that we encounter in our input list `is`, we want to check if there's
a matching number in the rest of the list. We only search the remainder of the list
because we can't use `x` twice: the `x` and `y` we return that add up to `total`
must be different elements. We use `find_matching` to try find a complementary number
for `x`. If we don't find it, this `x` isn't it, so we recursively move on to `xs`.
On the other hand, if we _do_ find a matching `y`, we're done! We return `(x,y)`,
wrapped in `Some` to indicate that we found something useful.
What about that `(* Was buggy! *)` line? Well, it so happens that my initial
implementation had a bug on this line, one that came up as I was proving
the correctness of my function. When I wasn't able to prove a particular
behavior in one of the cases, I realized something was wrong. In short,
my proof actually helped me find and fix a bug!
This is all the code we'll need to get our solution. Next, let's talk about some
properties of our two functions.
### Our First Lemma
When we call `find_matching`, we want to be sure that if we get a number,
it does indeed add up to our expected total. We can state it a little bit more
formally as follows:
> For any numbers `k` and `x`, and for any list of number `is`,
> if `find_matching is k x` returns a number `y`, then `x + y = k`.
And this is how we write it in Coq:
{{< codelines "Coq" "aoc-coq/day1.v" 26 27 >}}
The arrow, `->`, reads "implies". Other than that, I think this
property reads pretty well. The proof, unfortunately, is a little bit more involved.
Here are the first few lines:
{{< codelines "Coq" "aoc-coq/day1.v" 28 31 >}}
We start with the `intros is` tactic, which is akin to saying
"consider a particular list of integers `is`". We do this without losing
generality: by simply examining a concrete list, we've said nothing about
what that list is like. We then proceed by induction on `is`.
To prove something by induction for a list, we need to prove two things:
* The __base case__. Whatever property we want to hold, it must
hold for the empty list, which is the simplest possible list.
In our case, this means `find_matching` searching an empty list.
* The __inductive case__. Assuming that a property holds for any list
`[b, c, ...]`, we want to show that the property also holds for
the list `[a, b, c, ...]`. That is, the property must remain true if we
prepend an element to a list for which this property holds.
These two things combined give us a proof for _all_ lists, which is exactly
what we want! If you don't belive me, here's how it works. Suppose you want
to prove that some property `P` holds for `[1,2,3,4]`. Given the base
case, we know that `P []` holds. Next, by the inductive case, since
`P []` holds, we can prepend `4` to the list, and the property will
still hold. Thus, `P [4]`. Now that `P [4]` holds, we can again prepend
an element to the list, this time a `3`, and conclude that `P [3,4]`.
Repeating this twice more, we arrive at our desired fact: `P [1,2,3,4]`.
When we write `induction is`, Coq will generate two proof goals for us,
one for the base case, and one for the inductive case. We will have to prove
each of them separately. Since we have
not yet introduced the variables `k`, `x`, and `y`, they remain
inside a `forall` quantifier at that time. To be able to refer
to them, we want to use `intros`. We want to do this in both the
base and the inductive case. To quickly do this, we use Coq's `;`
operator. When we write `a; b`, Coq runs the tactic `a`, and then
runs the tactic `b` in every proof goal generated by `a`. This is
exactly what we want.
There's one more variable inside our second `intros`: `Hev`.
This variable refers to the hypothesis of our statement:
that is, the part on the left of the `->`. To prove that `A`
implies `B`, we assume that `A` holds, and try to argue `B` from there.
Here is no different: when we use `intros Hev`, we say, "suppose that you have
a proof that `find_matching` evaluates to `Some y`, called `Hev`". The thing
on the right of `->` becomes our proof goal.
Now, it's time to look at the cases. To focus on one case at a time,
we use `-`. The first case is our base case. Here's what Coq prints
out at this time:
```
k, x, y : nat
Hev : find_matching nil k x = Some y
========================= (1 / 1)
x + y = k
```
All the stuff above the `===` line are our hypotheses. We know
that we have some `k`, `x`, and `y`, all of which are numbers.
We also have the assumption that `find_matching` returns `Some y`.
In the base case, `is` is just `[]`, and this is reflected in the
type for `Hev`. To make this more clear, we'll simplify the call to `find_matching`
in `Hev`, using `simpl in Hev`. Now, here's what Coq has to say about `Hev`:
```
Hev : None = Some y
```
Well, this doesn't make any sense. How can something be equal to nothing?
We ask Coq this question using `inversion Hev`. Effectively, the question
that `inversion` asks is: what are the possible ways we could have acquired `Hev`?
Coq generates a proof goal for each of these possible ways. Alas, there are
no ways to arrive at this contradictory assumption: the number of proof sub-goals
is zero. This means we're done with the base case!
The inductive case is the meat of this proof. Here's the corresponding part
of the Coq source file:
{{< codelines "Coq" "aoc-coq/day1.v" 32 36 >}}
This time, the proof state is more complicated:
```
a : nat
is : list nat
IHis : forall k x y : nat, find_matching is k x = Some y -> x + y = k
k, x, y : nat
Hev : find_matching (a :: is) k x = Some y
========================= (1 / 1)
x + y = k
```
Following the footsteps of our informal description of the inductive case,
Coq has us prove our property for `(a :: is)`, or the list `is` to which
`a` is being prepended. Like before, we assume that our property holds for `is`.
This is represented in the __induction hypothesis__ `IHis`. It states that if
`find_matching` finds a `y` in `is`, it must add up to `k`. However, `IHis`
doesn't tell us anything about `a :: is`: that's our job. We also still have
`Hev`, which is our assumption that `find_matching` finds a `y` in `(a :: is)`.
Running `simpl in Hev` gives us the following:
```
Hev : (if x + a =? k then Some a else find_matching is k x) = Some y
```
The result of `find_matching` now depends on whether or not the new element `a`
adds up to `k`. If it does, then `find_matching` will return `a`, which means
that `y` is the same as `a`. If not, it must be that `find_matching` finds
the `y` in the rest of the list, `is`. We're not sure which of the possibilities
is the case. Fortunately, we don't need to be!
If we can prove that the `y` that `find_matching` finds is correct regardless
of whether `a` adds up to `k` or not, we're good to go! To do this,
we perform case analysis using `destruct`.
Our particular use of `destruct` says: check any possible value for `x + a ?= k`,
and create an equation `Heq` that tells us what that value is. `?=` returns a boolean
value, and so `destruct` generates two new goals: one where the function returns `true`,
and one where it returns `false`. We start with the former. Here's the proof state:
```
a : nat
is : list nat
IHis : forall k x y : nat, find_matching is k x = Some y -> x + y = k
k, x, y : nat
Heq : (x + a =? k) = true
Hev : Some a = Some y
========================= (1 / 1)
x + y = k
```
There is a new hypothesis: `Heq`. It tells us that we're currently
considering the case where `?=` evaluates to `true`. Also,
`Hev` has been considerably simplified: now that we know the condition
of the `if` expression, we can just replace it with the `then` branch.
Looking at `Hev`, we can see that our prediction was right: `a` is equal to `y`. After all,
if they weren't, `Some a` wouldn't equal to `Some y`. To make Coq
take this information into account, we use `injection`. This will create
a new hypothesis, `a = y`. But if one is equal to the other, why don't we
just use only one of these variables everywhere? We do exactly that by using
`subst`, which replaces `a` with `y` everywhere in our proof.
The proof state is now:
```
is : list nat
IHis : forall k x y : nat, find_matching is k x = Some y -> x + y = k
k, x, y : nat
Heq : (x + y =? k) = true
========================= (1 / 1)
x + y = k
```
We're close, but there's one more detail to keep in mind. Our goal, `x + y = k`,
is the __proposition__ that `x + y` is equal to `k`. However, `Heq` tells us
that the __function__ `?=` evaluates to `true`. These are fundamentally different.
One talks about mathematical equality, while the other about some function `?=`
defined somewhere in Coq's standard library. Who knows - maybe there's a bug in
Coq's implementation! Fortunately, Coq comes with a proof that if two numbers
are equal according to `?=`, they are mathematically equal. This proof is
called `eqb_nat_eq`. We tell Coq to use this with `apply`. Our proof goal changes to:
```
true = (x + y =? k)
```
This is _almost_ like `Heq`, but flipped. Instead of manually flipping it and using `apply`
with `Heq`, I let Coq do the rest of the work using `auto`.
Phew! All this for the `true` case of `?=`. Next, what happens if `x + a` does not equal `k`?
Here's the proof state at this time:
```
a : nat
is : list nat
IHis : forall k x y : nat, find_matching is k x = Some y -> x + y = k
k, x, y : nat
Heq : (x + a =? k) = false
Hev : find_matching is k x = Some y
========================= (1 / 1)
x + y = k
```
Since `a` was not what it was looking for, `find_matching` moved on to `is`. But hey,
we're in the inductive case! We are assuming that `find_matching` will work properly
with the list `is`. Since `find_matching` found its `y` in `is`, this should be all we need!
We use our induction hypothesis `IHis` with `apply`. `IHis` itself does not know that
`find_matching` moved on to `is`, so it asks us to prove it. Fortunately, `Hev` tells us
exactly that, so we use `assumption`, and the proof is complete! Quod erat demonstrandum, QED!
### The Rest of the Owl
Here are a couple of other properties of `find_matching`. For brevity's sake, I will
not go through their proofs step-by-step. I find that the best way to understand
Coq proofs is to actually step through them in the IDE!
First on the list is `find_matching_skip`. Here's the type:
{{< codelines "Coq" "aoc-coq/day1.v" 38 39 >}}
It reads: if we correctly find a number in a small list `is`, we can find that same number
even if another number is prepended to `is`. That makes sense: _adding_ a number to
a list doesn't remove whatever we found in it! I used this lemma to prove another,
`find_matching_works`:
{{< codelines "Coq" "aoc-coq/day1.v" 49 50 >}}
This reads, if there _is_ an element `y` in `is` that adds up to `k` with `x`, then
`find_matching` will find it. This is an important property. After all, if it didn't
hold, it would mean that `find_matching` would occasionally fail to find a matching
number, even though it's there! We can't have that.
Finally, we want to specify what it means for `find_sum`, our solution function, to actually
work. The naive definition would be:
> Given a list of integers, `find_sum` always finds a pair of numbers that add up to `k`.
Unfortunately, this is not true. What if, for instance, we give `find_sum` an empty list?
There are no numbers from that list to find and add together. Even a non-empty list
may not include such a pair! We need a way to characterize valid input lists. I claim
that all lists from this Advent of Code puzzle are guaranteed to have two numbers that
add up to our goal, and that these numbers are not equal to each other. In Coq,
we state this as follows:
{{< codelines "Coq" "aoc-coq/day1.v" 4 5 >}}
This defines a new property, `has_pair t is` (read "`is` has a pair of numbers that add to `t`"),
which means:
> There are two numbers `n1` and `n2` such that, they are not equal to each other (`n1<>n2`) __and__
> the number `n1` is an element of `is` (`In n1 is`) __and__
> the number `n2` is an element of `is` (`In n2 is`) __and__
> the two numbers add up to `t` (`n1 + n2 = t`).
When making claims about the correctness of our algorithm, we will assume that this
property holds. Finally, here's the theorem we want to prove:
{{< codelines "Coq" "aoc-coq/day1.v" 64 66 >}}
It reads, "for any total `k` and list `is`, if `is` has a pair of numbers that add to `k`,
then `find_sum` will return a pair of numbers `x` and `y` that add to `k`".
There's some nuance here. We hardly reference the `has_pair` property in this definition,
and for good reason. Our `has_pair` hypothesis only says that there is _at least one_
pair of numbers in `is` that meets our criteria. However, this pair need not be the only
one, nor does it need to be the one returned by `find_sum`! However, if we have many pairs,
we want to confirm that `find_sum` will find one of them. Finally, here is the proof.
I will not be able to go through it in detail in this post, but I did comment it to
make it easier to read:
{{< codelines "Coq" "aoc-coq/day1.v" 67 102 >}}
Coq seems happy with it, and so am I! The bug I mentioned earlier popped up on line 96.
I had accidentally made `find_sum` return `None` if it couldn't find a complement
for the `x` it encountered. This meant that it never recursed into the remaining
list `xs`, and thus, the pair was never found at all! It this became impossible
to prove that `find_some` will return `Some y`, and I had to double back
and check my definitions.
I hope you enjoyed this post! If you're interested to learn more about Coq, I strongly recommend
checking out [Software Foundations](https://softwarefoundations.cis.upenn.edu/), a series
of books on Coq written as comments in a Coq source file! In particular, check out
[Logical Foundations](https://softwarefoundations.cis.upenn.edu/lf-current/index.html)
for an introduction to using Coq. Thanks for reading!

Some files were not shown because too many files have changed in this diff Show More