MYF

HDU 5769 Substring

题目链接

HDU 5769

题目类型:后缀数组变形

题目来源:2016年多校Round4

题目分析

题目大意

给出一个字符ch以及一个字符串s,问s中有多少种含字符ch的连续子串。

解析

从蓝书(§3.4)中好好学习了一下后缀数组,讲的还是比较详细的,可惜似乎书上的代码有问题,不过从这节领悟一下精神还是不错的。

对于一个字符串来讲,共计有∑(len - sa[i] - height[i])个子串,len - sa[i] - height[i]表达的意义是:从sa[i]开始的后缀串,除去之前已经贡献的串,还能贡献的该后缀串的前缀串的个数。我们需要筛选出含有字符x的串,那么就要对后面减去的这堆玩意更新一下,因为起码要包含一个吧,所以要减去max(sa[i]+height[i], nxt[i])。理解了上面这句话,这题就不难做了。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include <set>
#include <map>
#include <stack>
#include <cmath>
#include <queue>
#include <cstdio>
#include <string>
#include <vector>
#include <iomanip>
#include <bitset>
#include <cstring>
#include <iostream>
#include <deque>
#include <algorithm>
#define Memset(a,val) memset(a,val,sizeof(a))
#define PI acos(-1)
#define PB push_back
#define MP make_pair
#define rt(n) (i == n ? '\n' : ' ')
#define hi printf("Hi----------\n")
#define IN freopen("input.txt","r",stdin);
#define OUT freopen("output.txt","w",stdout);
#define debug(x) cout<<"Debug : ---"<<x<<"---"<<endl;
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
typedef pair<int,int> PII;
typedef long long ll;
const int maxn=100000+5;
const int mod=1000000007;
const int INF=0x3f3f3f3f;
const double eps=1e-8;
/* start 后缀数组 start */
char s[maxn],ch[2];
/*rk的有效信息从0到n-1
sa,ht的有效信息从1到n
在传参的时候要多传一位*/

int rk[maxn], sa[maxn], ht[maxn], wa[maxn], wb[maxn], wx[maxn], wv[maxn];

bool isq(int *r, int a, int b, int len) {
return r[a] == r[b] && r[a + len] == r[b + len];
}

bool isEqual(int *r, int a, int b, int len) {
return r[a] == r[b] && r[a + len] == r[b + len];
}
// r数组有效信息为0~n-1,m为(数组中最大值的上界+1)
void getSa(char r[], int n, int m) {
int i, j, p, *t, *x = wa, *y = wb;
for (i = 0; i < m; ++i)
wx[i] = 0;
for (i = 0; i < n; ++i)
++wx[x[i] = r[i]];
for (i = 1; i < m; ++i)
wx[i] += wx[i - 1];
for (i = n - 1; i >= 0; --i)
sa[--wx[x[i]]] = i;
for (j = 1, p = 0; p < n; j <<= 1, m = p) {
for (p = 0, i = n - j; i < n; ++i)
y[p++] = i;
for (i = 0; i < n; ++i)
sa[i] >= j ? y[p++] = sa[i] - j : 0;
for (i = 0; i < m; ++i)
wx[i] = 0;
for (i = 0; i < n; ++i)
++wx[wv[i] = x[y[i]]];
for (i = 1; i < m; ++i)
wx[i] += wx[i - 1];
for (i = n - 1; i >= 0; --i)
sa[--wx[wv[i]]] = y[i];
p = 1, t = x, x = y, y = t;
x[sa[0]] = 0;
for (i = 1; i < n; ++i)
x[sa[i]] = isEqual(y, sa[i], sa[i - 1], j) ? p - 1 : p++;
}
}

void getHet(char r[], int n) {
int i, j, k = 0;
for (i = 1; i <= n; ++i)
rk[sa[i]] = i;
for (i = 0; i < n; ht[rk[i++]] = k) {
k = k > 0 ? k - 1 : 0;
j = sa[rk[i] - 1];
while (r[i + k] == r[j + k])
++k;
}
}
int nxt[maxn];
int main(){
int T;
scanf("%d",&T);
for (int cas=1; cas<=T; cas++) {
Memset(nxt, 0);
scanf("%s",ch);
scanf("%s",s);
int len = (int)strlen(s);
for (int i=0; i<len; i++) {
nxt[i]=len;
if (s[i]==ch[0]) {
nxt[i]=i;
int idx=i-1;
while (idx>=0&&s[idx]!=ch[0]) {
nxt[idx--]=i;
}
}
}
getSa(s, len+1, 'z'+1);
getHet(s, len);

ll ans = 0;
for (int i=1; i<=len; i++) {
ans+=len-max(sa[i]+ht[i], nxt[sa[i]]);
}
printf("Case #%d: %lld\n",cas,ans);
}
}