-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathsearch_word_alignment.ipynb.broken
More file actions
147 lines (147 loc) · 5.32 KB
/
search_word_alignment.ipynb.broken
File metadata and controls
147 lines (147 loc) · 5.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
},
"orig_nbformat": 4,
"kernelspec": {
"name": "python3",
"display_name": "Python 3.9.2 64-bit ('venv': venv)"
},
"interpreter": {
"hash": "effe7c91cd0a74db529d06d3b051130f165eb4dab122b3e9455a73eecfba9cb1"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"source": [
"# Search - Word Alignment"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from vowpalwabbit import pyvw\n",
"\n",
"# the dataset is triples of E, A, F where A[i] = list of words E_i\n",
"# aligned to, or [] for null-aligned\n",
"my_dataset = [\n",
" ( \"the blue house\".split(),\n",
" ([0], [2], [1]),\n",
" \"la maison bleue\".split() ),\n",
" ( \"the house\".split(),\n",
" ([0], [1]),\n",
" \"la maison\".split() ),\n",
" ( \"the flower\".split(),\n",
" ([0], [1]),\n",
" \"la fleur\".split() )\n",
" ]\n",
"\n",
"my_dataset2 = [\n",
" ( \"mary did not slap the green witch\".split(),\n",
" ([0], [], [1],[2,3,4],[6],[8], [7]),\n",
" \"maria no dio una bofetada a la bruja verde\".split() ) ]\n",
" # 0 1 2 3 4 5 6 7 8\n",
"\n",
"\n",
"def alignmentError(true, sys):\n",
" t = set(true)\n",
" s = set(sys)\n",
" if len(t | s) == 0: return 0.\n",
" return 1. - float(len(t & s)) / float(len(t | s))\n",
"\n",
"\n",
"class WordAligner(pyvw.SearchTask):\n",
" def __init__(self, vw, sch, num_actions):\n",
" pyvw.SearchTask.__init__(self, vw, sch, num_actions)\n",
" sch.set_options( sch.AUTO_HAMMING_LOSS | sch.IS_LDF | sch.AUTO_CONDITION_FEATURES )\n",
"\n",
" def makeExample(self, E, F, i, j0, l):\n",
" f = 'Null' if j0 is None else [ F[j0+k] for k in range(l+1) ]\n",
" ex = self.vw.example( { 'e': E[i],\n",
" 'f': f,\n",
" 'p': '_'.join(f),\n",
" 'l': str(l),\n",
" 'o': [str(i-j0), str(i-j0-l)] if j0 is not None else [] },\n",
" labelType = self.vw.lCostSensitive )\n",
" lab = 'Null' if j0 is None else str(j0+l)\n",
" ex.set_label_string(lab + ':0')\n",
" return ex\n",
"\n",
" def _run(self, alignedSentence):\n",
" E,A,F = alignedSentence\n",
"\n",
" # for each E word, we pick a F span\n",
" covered = {} # which F words have been covered so far?\n",
" output = []\n",
"\n",
" for i in range(len(E)):\n",
" examples = [] # contains vw examples\n",
" spans = [] # contains triples (alignment error, index in examples, [range])\n",
"\n",
" # empty span:\n",
" examples.append( self.makeExample(E, F, i, None, None) )\n",
" spans.append( (alignmentError(A[i], []), 0, []) )\n",
"\n",
" # non-empty spans\n",
" for j0 in range(len(F)):\n",
" for l in range(3): # max phrase length of 3\n",
" if j0+l >= len(F):\n",
" break\n",
" if j0+l in covered:\n",
" break\n",
"0\n",
" id = len(examples)\n",
" examples.append( self.makeExample(E, F, i, j0, l) )\n",
" spans.append( (alignmentError(A[i], list(range(j0,j0+l+1))), id, list(range(j0,j0+l+1))) )\n",
"\n",
" sortedSpans = []\n",
" for s in spans: sortedSpans.append(s)\n",
" sortedSpans.sort()\n",
" oracle = []\n",
" for id in range(len(sortedSpans)):\n",
" if sortedSpans[id][0] > sortedSpans[0][0]: break\n",
" oracle.append( sortedSpans[id][1] )\n",
"\n",
" pred = self.sch.predict(examples = examples,\n",
" my_tag = i+1,\n",
" oracle = oracle,\n",
" condition = [ (i, 'p'), (i-1, 'q') ] )\n",
"\n",
" self.vw.finish_example(examples)\n",
"\n",
" output.append( spans[pred][2] )\n",
" for j in spans[pred][2]:\n",
" covered[j] = True\n",
"\n",
" return output\n",
"\n",
"\n",
"print('training LDF')\n",
"vw = pyvw.Workspace(\"--search 0 --csoaa_ldf m --search_task hook --ring_size 1024 --quiet -q ef -q ep\")\n",
"task = vw.init_search_task(WordAligner)\n",
"for p in range(10):\n",
" task.learn(my_dataset)\n",
"print('====== test ======')\n",
"print(task.predict( (\"the blue flower\".split(), ([],[],[]), \"la fleur bleue\".split()) ))\n",
"print('should have printed [[0], [2], [1]]')\n"
]
}
]
}