Skip to content

Commit 0fc4910

Browse files
klauslermemfrob
authored andcommitted
[flang] COMPLEX folding
Original-commit: flang-compiler/f18@6f1ef45 Reviewed-on: flang-compiler/f18#162 Tree-same-pre-rewrite: false
1 parent a5eb4f9 commit 0fc4910

File tree

8 files changed

+230
-84
lines changed

8 files changed

+230
-84
lines changed

flang/lib/evaluate/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ template<typename A> struct ValueWithRealFlags {
8282

8383
ENUM_CLASS(Rounding, TiesToEven, ToZero, Down, Up, TiesAwayFromZero)
8484

85+
static constexpr Rounding defaultRounding{Rounding::TiesToEven};
86+
8587
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
8688
constexpr bool IsHostLittleEndian{false};
8789
#elif __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__

flang/lib/evaluate/complex.cc

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,40 @@
1717
namespace Fortran::evaluate::value {
1818

1919
template<typename R>
20-
ValueWithRealFlags<Complex<R>> Complex<R>::Add(const Complex &that) const {
20+
ValueWithRealFlags<Complex<R>> Complex<R>::Add(
21+
const Complex &that, Rounding rounding) const {
2122
RealFlags flags;
22-
Part reSum{re_.Add(that.re_).AccumulateFlags(flags)};
23-
Part imSum{im_.Add(that.im_).AccumulateFlags(flags)};
23+
Part reSum{re_.Add(that.re_, rounding).AccumulateFlags(flags)};
24+
Part imSum{im_.Add(that.im_, rounding).AccumulateFlags(flags)};
2425
return {Complex{reSum, imSum}, flags};
2526
}
2627

2728
template<typename R>
28-
ValueWithRealFlags<Complex<R>> Complex<R>::Subtract(const Complex &that) const {
29+
ValueWithRealFlags<Complex<R>> Complex<R>::Subtract(
30+
const Complex &that, Rounding rounding) const {
2931
RealFlags flags;
30-
Part reDiff{re_.Subtract(that.re_).AccumulateFlags(flags)};
31-
Part imDiff{im_.Subtract(that.im_).AccumulateFlags(flags)};
32+
Part reDiff{re_.Subtract(that.re_, rounding).AccumulateFlags(flags)};
33+
Part imDiff{im_.Subtract(that.im_, rounding).AccumulateFlags(flags)};
3234
return {Complex{reDiff, imDiff}, flags};
3335
}
3436

3537
template<typename R>
36-
ValueWithRealFlags<Complex<R>> Complex<R>::Multiply(const Complex &that) const {
38+
ValueWithRealFlags<Complex<R>> Complex<R>::Multiply(
39+
const Complex &that, Rounding rounding) const {
3740
// (a + ib)*(c + id) -> ac - bd + i(ad + bc)
3841
RealFlags flags;
39-
Part ac{re_.Multiply(that.re_).AccumulateFlags(flags)};
40-
Part bd{im_.Multiply(that.im_).AccumulateFlags(flags)};
41-
Part ad{re_.Multiply(that.im_).AccumulateFlags(flags)};
42-
Part bc{im_.Multiply(that.re_).AccumulateFlags(flags)};
43-
Part acbd{ac.Subtract(bd).AccumulateFlags(flags)};
44-
Part adbc{ad.Add(bc).AccumulateFlags(flags)};
42+
Part ac{re_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
43+
Part bd{im_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
44+
Part ad{re_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
45+
Part bc{im_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
46+
Part acbd{ac.Subtract(bd, rounding).AccumulateFlags(flags)};
47+
Part adbc{ad.Add(bc, rounding).AccumulateFlags(flags)};
4548
return {Complex{acbd, adbc}, flags};
4649
}
4750

4851
template<typename R>
49-
ValueWithRealFlags<Complex<R>> Complex<R>::Divide(const Complex &that) const {
52+
ValueWithRealFlags<Complex<R>> Complex<R>::Divide(
53+
const Complex &that, Rounding rounding) const {
5054
// (a + ib)/(c + id) -> [(a+ib)*(c-id)] / [(c+id)*(c-id)]
5155
// -> [ac+bd+i(bc-ad)] / (cc+dd)
5256
// -> ((ac+bd)/(cc+dd)) + i((bc-ad)/(cc+dd))
@@ -55,30 +59,30 @@ ValueWithRealFlags<Complex<R>> Complex<R>::Divide(const Complex &that) const {
5559
RealFlags flags;
5660
bool cGEd{that.re_.ABS().Compare(that.im_.ABS()) != Relation::Less};
5761
if (cGEd) {
58-
scale = that.im_.Divide(that.re_).AccumulateFlags(flags);
62+
scale = that.im_.Divide(that.re_, rounding).AccumulateFlags(flags);
5963
} else {
60-
scale = that.re_.Divide(that.im_).AccumulateFlags(flags);
64+
scale = that.re_.Divide(that.im_, rounding).AccumulateFlags(flags);
6165
}
6266
Part den;
6367
if (cGEd) {
64-
Part dS{scale.Multiply(that.im_).AccumulateFlags(flags)};
65-
den = dS.Add(that.re_).AccumulateFlags(flags);
68+
Part dS{scale.Multiply(that.im_, rounding).AccumulateFlags(flags)};
69+
den = dS.Add(that.re_, rounding).AccumulateFlags(flags);
6670
} else {
67-
Part cS{scale.Multiply(that.re_).AccumulateFlags(flags)};
68-
den = cS.Add(that.im_).AccumulateFlags(flags);
71+
Part cS{scale.Multiply(that.re_, rounding).AccumulateFlags(flags)};
72+
den = cS.Add(that.im_, rounding).AccumulateFlags(flags);
6973
}
70-
Part aS{scale.Multiply(re_).AccumulateFlags(flags)};
71-
Part bS{scale.Multiply(im_).AccumulateFlags(flags)};
74+
Part aS{scale.Multiply(re_, rounding).AccumulateFlags(flags)};
75+
Part bS{scale.Multiply(im_, rounding).AccumulateFlags(flags)};
7276
Part re1, im1;
7377
if (cGEd) {
74-
re1 = re_.Add(bS).AccumulateFlags(flags);
75-
im1 = im_.Subtract(aS).AccumulateFlags(flags);
78+
re1 = re_.Add(bS, rounding).AccumulateFlags(flags);
79+
im1 = im_.Subtract(aS, rounding).AccumulateFlags(flags);
7680
} else {
77-
re1 = aS.Add(im_).AccumulateFlags(flags);
78-
im1 = bS.Subtract(re_).AccumulateFlags(flags);
81+
re1 = aS.Add(im_, rounding).AccumulateFlags(flags);
82+
im1 = bS.Subtract(re_, rounding).AccumulateFlags(flags);
7983
}
80-
Part re{re1.Divide(den).AccumulateFlags(flags)};
81-
Part im{im1.Divide(den).AccumulateFlags(flags)};
84+
Part re{re1.Divide(den, rounding).AccumulateFlags(flags)};
85+
Part im{im1.Divide(den, rounding).AccumulateFlags(flags)};
8286
return {Complex{re, im}, flags};
8387
}
8488

flang/lib/evaluate/complex.h

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ template<typename REAL_TYPE> class Complex {
2929
constexpr Complex(const Complex &) = default;
3030
constexpr Complex(const Part &r, const Part &i) : re_{r}, im_{i} {}
3131
explicit constexpr Complex(const Part &r) : re_{r} {}
32+
constexpr Complex &operator=(const Complex &) = default;
33+
constexpr Complex &operator=(Complex &&) = default;
3234

3335
constexpr const Part &REAL() const { return re_; }
3436
constexpr const Part &AIMAG() const { return im_; }
@@ -40,10 +42,43 @@ template<typename REAL_TYPE> class Complex {
4042
im_.Compare(that.im_) == Relation::Equal;
4143
}
4244

43-
ValueWithRealFlags<Complex> Add(const Complex &) const;
44-
ValueWithRealFlags<Complex> Subtract(const Complex &) const;
45-
ValueWithRealFlags<Complex> Multiply(const Complex &) const;
46-
ValueWithRealFlags<Complex> Divide(const Complex &) const;
45+
constexpr bool IsZero() const { return re_.IsZero() || im_.IsZero(); }
46+
47+
constexpr bool IsInfinite() const {
48+
return re_.IsInfinite() || im_.IsInfinite();
49+
}
50+
51+
constexpr bool IsNotANumber() const {
52+
return re_.IsNotANumber() || im_.IsNotANumber();
53+
}
54+
55+
constexpr bool IsSignalingNaN() const {
56+
return re_.IsSignalingNaN() || im_.IsSignalingNaN();
57+
}
58+
59+
template<typename INT>
60+
static ValueWithRealFlags<Complex> FromInteger(
61+
const INT &n, Rounding rounding = defaultRounding) {
62+
ValueWithRealFlags<Complex> result;
63+
result.value.re_ =
64+
Part::FromInteger(n, rounding).AccumulateFlags(result.flags);
65+
return result;
66+
}
67+
68+
ValueWithRealFlags<Complex> Add(
69+
const Complex &, Rounding rounding = defaultRounding) const;
70+
ValueWithRealFlags<Complex> Subtract(
71+
const Complex &, Rounding rounding = defaultRounding) const;
72+
ValueWithRealFlags<Complex> Multiply(
73+
const Complex &, Rounding rounding = defaultRounding) const;
74+
ValueWithRealFlags<Complex> Divide(
75+
const Complex &, Rounding rounding = defaultRounding) const;
76+
77+
constexpr Complex FlushDenormalToZero() const {
78+
return {re_.FlushDenormalToZero(), im_.FlushDenormalToZero()};
79+
}
80+
81+
static constexpr Complex NaN() { return {Part::NaN(), Part::NaN()}; }
4782

4883
std::string DumpHexadecimal() const;
4984
// TODO: (C)ABS once Real::HYPOT is done

0 commit comments

Comments
 (0)